Specifying the batch size when subclassing keras.Model
I implemented a model that includes an LSTM layer by subclassing keras.Model. Here is the code:
import tensorflow as tf
from tensorflow import keras
class SequencePredictor(keras.Model):
    def __init__(self, cell_size):
        super().__init__()
        self._mask = keras.layers.Masking(mask_value=0.0)
        # stateful=True carries the LSTM state over from one batch to the next
        self._lstm = keras.layers.LSTM(cell_size, return_sequences=True,
                                       stateful=True)
        self._dense = keras.layers.Dense(1)
    def call(self, inputs, training=False):
        out = self._mask(inputs)
        out = self._lstm(out)
        return self._dense(out)
sequence_predictor = SequencePredictor(cell_size=10)
# DataQueue is a custom helper (not shown here); .dataset supplies the training batches
train_dataset = DataQueue(5000, 500).dataset
sequence_predictor.compile(
    loss="mean_squared_error",
    optimizer=keras.optimizers.Adam(learning_rate=0.05),
    metrics=["mse"])
sequence_predictor.fit(train_dataset, epochs=50)
The code above fails with the following error:
  ValueError: If a RNN is stateful, it needs to know its batch size. Specify the batch size of your input tensors: 
    - If using a Sequential model, specify the batch size by passing a `batch_input_shape` argument to your first layer.
    - If using the functional API, specify the batch size by passing a `batch_shape` argument to your Input layer.
The problem is that this model is neither a Sequential nor a functional model. I tried specifying batch_input_shape on the first layer (the Masking layer), but that did not help.
How can I resolve this error? I am currently using TensorFlow 2.0 rc0.
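To see where the error comes from, here is a minimal sketch (written against a recent TensorFlow 2.x, not tested on 2.0 rc0) that builds two fresh stateful LSTM layers directly. A stateful LSTM needs a fixed batch size to allocate its state, so building it with an unknown (None) batch dimension raises a ValueError, while building it with a concrete batch size succeeds. The helper try_build and the shapes used (5 timesteps, 3 features, batch of 8) are illustrative choices, not part of the question, and the exact error wording differs between Keras versions.

```python
from tensorflow import keras


def try_build(input_shape):
    """Build a fresh stateful LSTM; return None on success, else the error text."""
    layer = keras.layers.LSTM(4, return_sequences=True, stateful=True)
    try:
        layer.build(input_shape)
    except ValueError as err:
        return str(err)
    return None


# Batch dimension unknown (None), as can happen when Keras builds the layer
# lazily from a batch whose size is not fixed: building fails.
unknown_batch_error = try_build((None, 5, 3))
# Batch dimension fixed to 8: building succeeds.
known_batch_error = try_build((8, 5, 3))

print("batch=None ->", unknown_batch_error)
print("batch=8    ->", known_batch_error)
```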
Solution 1:[1]
Build the LSTM layer explicitly, right after instantiating it, with an input shape that includes the batch size.
self._lstm = keras.layers.LSTM(cell_size, return_sequences=True, stateful=True)
# input_shape must include the batch size, e.g. (batch_size, timesteps, features)
self._lstm.build(input_shape)
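Applied to the question's model, the fix might look like the sketch below. It assumes a recent TensorFlow 2.x (untested on 2.0 rc0). The batch_size and n_features constructor arguments are hypothetical additions, and because the question's DataQueue class is not shown, a small random tf.data.Dataset stands in for it. That dataset is batched with drop_remainder=True so that every batch has exactly batch_size rows, which a stateful LSTM built for a fixed batch size expects.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

BATCH_SIZE = 8   # illustrative values, not taken from the question
TIMESTEPS = 5
N_FEATURES = 3


class SequencePredictor(keras.Model):
    # batch_size and n_features are hypothetical extra arguments for this sketch.
    def __init__(self, cell_size, batch_size, n_features):
        super().__init__()
        self._mask = keras.layers.Masking(mask_value=0.0)
        self._lstm = keras.layers.LSTM(cell_size, return_sequences=True,
                                       stateful=True)
        # Build the stateful LSTM now with a fixed batch size (Solution 1);
        # the time dimension can stay None.
        self._lstm.build((batch_size, None, n_features))
        self._dense = keras.layers.Dense(1)

    def call(self, inputs, training=False):
        out = self._mask(inputs)
        out = self._lstm(out)
        return self._dense(out)


# Random stand-in for the question's DataQueue (hypothetical data).
# drop_remainder=True guarantees every batch has exactly BATCH_SIZE rows.
x = np.random.rand(100, TIMESTEPS, N_FEATURES).astype("float32")
y = np.random.rand(100, TIMESTEPS, 1).astype("float32")
train_dataset = (tf.data.Dataset.from_tensor_slices((x, y))
                 .batch(BATCH_SIZE, drop_remainder=True))

sequence_predictor = SequencePredictor(cell_size=10, batch_size=BATCH_SIZE,
                                       n_features=N_FEATURES)
sequence_predictor.compile(
    loss="mean_squared_error",
    optimizer=keras.optimizers.Adam(learning_rate=0.05),
    metrics=["mse"])
history = sequence_predictor.fit(train_dataset, epochs=2, verbose=0)
```

Only the batch size and the number of features are fixed when the layer is built; the time dimension is left as None.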
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source | 
|---|---|
| Solution 1 | Gustavo de Rosa | 
