TensorFlow LSTM/GRU: reset states once per epoch, not for each new batch

I train the following model based on a GRU; note that I am passing the argument stateful=True to the GRU constructor.

class LearningToSurpriseModel(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, rnn_units):
    super().__init__()
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(rnn_units,
                                   stateful=True,
                                   return_sequences=True,
                                   return_state=True,
                                   reset_after=True  
                                   )
    self.dense = tf.keras.layers.Dense(vocab_size)

  def call(self, inputs, states=None, return_state=False, training=False):
    x = inputs
    x = self.embedding(x, training=training)
    if states is None:
      states = self.gru.get_initial_state(x)
    x, states = self.gru(x, initial_state=states, training=training)
    x = self.dense(x, training=training)

    if return_state:
      return x, states
    else:
      return x

  @tf.function
  def train_step(self, inputs):
    [defining here my training step]

I instantiate my model

model = LearningToSurpriseModel(
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units
    )

[compile and do stuff]

and train for EPOCHS epochs

for i in range(EPOCHS):
  model.fit(train_dataset, validation_data=validation_dataset, epochs=1, callbacks=[EarlyS], verbose=1)
  model.reset_states()

What is the behavior of this code with regard to GRU states: are states updated for each new batch of data, or only for each new epoch? The desired behavior is a reset for each new epoch only. If that is not what happens, how can I implement it?

EDIT

TensorFlow implements the reset_states function for Model as follows:

  def reset_states(self):
    for layer in self.layers:
      if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
        layer.reset_states()

Does this mean that states are reset only if stateful=True? That is what I infer from the condition getattr(layer, 'stateful', False), which is truthy only for stateful layers (the False is just the default for layers that have no stateful attribute).
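A quick experiment supports this reading. The following is a minimal sketch (the layer sizes and input values are arbitrary, and it assumes the TF 2.x Keras API quoted above): a stateful GRU is built with a fixed batch size, run once so its state becomes non-trivial, then reset through the model.

```python
import numpy as np
import tensorflow as tf

# A stateful RNN needs a fixed batch size, hence batch_shape.
inputs = tf.keras.Input(batch_shape=(2, 3, 5))
outputs = tf.keras.layers.GRU(4, stateful=True)(inputs)
model = tf.keras.Model(inputs, outputs)

model.predict(np.ones((2, 3, 5), dtype="float32"))  # run once to populate the state
gru = model.layers[1]
print(np.abs(gru.states[0].numpy()).sum())  # generally non-zero after a forward pass

model.reset_states()
print(np.abs(gru.states[0].numpy()).sum())  # back to 0.0: the stateful layer WAS reset
```

So the condition does not prevent the reset; it restricts it to layers that are stateful in the first place.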



Solution 1:[1]

You can try resetting the states in a custom Callback:

model = LearningToSurpriseModel(
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units
    )

# the GRU is the second layer, after the Embedding
gru_layer = model.layers[1]

class CustomCallback(tf.keras.callbacks.Callback):
    def __init__(self, gru_layer):
        super().__init__()
        self.gru_layer = gru_layer

    def on_epoch_end(self, epoch, logs=None):
        self.gru_layer.reset_states()

model.fit(train_dataset, validation_data=validation_dataset, epochs=EPOCHS,
          callbacks=[EarlyS, CustomCallback(gru_layer)], verbose=1)

Note that a single fit call with epochs=EPOCHS now replaces the manual for-loop over epochs: the callback performs the reset at the end of every epoch, while states still carry over from batch to batch within an epoch.

Also, see this post regarding when to reset the GRU states.
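For reference, here is a self-contained toy version of the same pattern. The shapes, the random training data, and the ResetStatesCallback name are made up for illustration; it assumes the TF 2.x Keras API used in the question.

```python
import numpy as np
import tensorflow as tf

BATCH, STEPS, FEATURES, EPOCHS = 2, 3, 5, 3

# Toy stand-in for the model in the question: a stateful RNN needs a fixed batch size.
inputs = tf.keras.Input(batch_shape=(BATCH, STEPS, FEATURES))
h = tf.keras.layers.GRU(4, stateful=True, return_sequences=True)(inputs)
outputs = tf.keras.layers.Dense(1)(h)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="mse")

gru_layer = model.layers[1]

class ResetStatesCallback(tf.keras.callbacks.Callback):
    """Reset the recurrent state at the end of each epoch only; within an
    epoch the state still carries over from batch to batch."""
    def __init__(self, rnn_layer):
        super().__init__()
        self.rnn_layer = rnn_layer

    def on_epoch_end(self, epoch, logs=None):
        self.rnn_layer.reset_states()

# Made-up training data, sized to a whole number of batches.
x = np.random.rand(BATCH * 4, STEPS, FEATURES).astype("float32")
y = np.random.rand(BATCH * 4, STEPS, 1).astype("float32")

# One fit call covers all epochs; shuffle=False keeps batch order, which
# is what makes carrying state across batches meaningful.
model.fit(x, y, batch_size=BATCH, epochs=EPOCHS, shuffle=False,
          callbacks=[ResetStatesCallback(gru_layer)], verbose=0)
```

Because on_epoch_end also fires after the final epoch, the GRU state is all zeros once fit returns.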

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
