Making GRU/LSTM states trainable in TensorFlow/Keras and adding some random noise

I am training the following GRU-based model; note that I pass stateful=True to the GRU constructor.

import tensorflow as tf

class LearningToSurpriseModel(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, rnn_units):
    super().__init__()
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(rnn_units,
                                   stateful=True,
                                   return_sequences=True,
                                   return_state=True,
                                   reset_after=True  
                                   )
    self.dense = tf.keras.layers.Dense(vocab_size)

  def call(self, inputs, states=None, return_state=False, training=False):
    x = inputs
    x = self.embedding(x, training=training)
    if states is None:
      states = self.gru.get_initial_state(x)
    x, states = self.gru(x, initial_state=states, training=training)
    x = self.dense(x, training=training)

    if return_state:
      return x, states
    else:
      return x

  @tf.function
  def train_step(self, inputs):
    # [defining here my training step]
    ...

I instantiate my model:

model = LearningToSurpriseModel(
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units
    )

[compile and do stuff] The custom callback below resets states manually at the end of each epoch:

gru_layer = model.layers[1]

class CustomCallback(tf.keras.callbacks.Callback):
    def __init__(self, gru_layer):
        super().__init__()
        self.gru_layer = gru_layer

    def on_epoch_end(self, epoch, logs=None):
        self.gru_layer.reset_states()
        
model.fit(train_dataset, validation_data=validation_dataset,
          epochs=EPOCHS, callbacks=[EarlyS, CustomCallback(gru_layer)], verbose=1)

States will be reset to zero. I would like to follow the ideas in https://r2rt.com/non-zero-initial-states-for-recurrent-neural-networks.html to make the initial states trainable. The implementation in that post is based on raw TensorFlow and overrides native functions; perhaps there is a more elegant way in Keras.

(1) How do I make the states trainable?

(2) How do I combine trainable states with random initialization?
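Something along these lines is what I have in mind; a minimal, untested sketch (the TrainableInitialStateGRU name and the weight shape are my own placeholders, not taken from the r2rt post):

import tensorflow as tf

class TrainableInitialStateGRU(tf.keras.layers.Layer):
  def __init__(self, rnn_units, batch_size, **kwargs):
    super().__init__(**kwargs)
    self.gru = tf.keras.layers.GRU(rnn_units,
                                   return_sequences=True,
                                   return_state=True)
    # One learned initial state vector per sequence in the batch.
    self.initial_state = self.add_weight(
        name="initial_state",
        shape=(batch_size, rnn_units),
        initializer=tf.random_normal_initializer(stddev=0.2),
        trainable=True)

  def call(self, inputs, training=False):
    # Passing the variable as initial_state keeps it on the gradient
    # path, so backpropagation can update it.
    return self.gru(inputs, initial_state=self.initial_state,
                    training=training)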



Solution 1:[1]

Based on the very good answer above, here is the full code for solving the case:

class CustomGRULayer(tf.keras.layers.Layer):
  def __init__(self, rnn_units, batch_size):
    super(CustomGRULayer, self).__init__()
    self.rnn_units = rnn_units
    self.batch_size = batch_size
    self.gru = tf.keras.layers.GRU(self.rnn_units,
                                   stateful=True,
                                   return_sequences=True,
                                   return_state=True,
                                   reset_after=True,
                                   )
    self.w = None

  def build(self, input_shape):
    # Create the trainable initial state once the layer is built,
    # randomly initialized (one state vector per batch element).
    w_init = tf.random_normal_initializer(mean=0.0, stddev=0.2)
    self.w = tf.Variable(
        initial_value=w_init(shape=(self.batch_size, self.rnn_units),
                             dtype='float32'), trainable=True)

  def call(self, inputs):
    return self.gru(inputs, initial_state=self.w)
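A note on the design: w is created in build() rather than __init__(), so the variable is only allocated when the layer first sees data and its shape is fixed once. In the model below, the wrapped self.gru is driven directly; w enters both as the fallback initial state in call() and through the epoch-end callback that resets the stateful GRU to the trainable value.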
  

class LearningToSurpriseModel(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, rnn_units, batch_size):
    super().__init__()

    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru_layer = CustomGRULayer(rnn_units = rnn_units, batch_size = batch_size)   
    self.dense = tf.keras.layers.Dense(vocab_size)
 
  def call(self, inputs, states=None, return_state=False, training=False):
    x = inputs
    x = self.embedding(x, training=training)

    if states is None:
      # Fall back to the trainable initial state rather than an
      # all-zeros state, so that gradients reach self.gru_layer.w.
      states = self.gru_layer.w

    x, states = self.gru_layer.gru(x, initial_state=states, training=training)
    x = self.dense(x, training=training)
    if return_state:
      return x, states
    else:
      return x
    
model = LearningToSurpriseModel(
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units,
    batch_size=BATCH_SIZE
    )

model.compile(optimizer='adam', loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=[
                  tf.keras.metrics.SparseCategoricalAccuracy()]
              )

EarlyS = tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', restore_best_weights=True, patience=10, verbose=1)

# defining a custom callback for resetting states at the end of each epoch only
gru_layer = model.layers[1]

class CustomCallback(tf.keras.callbacks.Callback):
    def __init__(self, gru_layer):
        super().__init__()
        self.gru_layer = gru_layer

    def on_epoch_end(self, epoch, logs=None):
        # Reset the stateful GRU to the trainable initial state,
        # not to zeros.
        self.gru_layer.gru.reset_states(self.gru_layer.w)
        
model.fit(train_dataset, validation_data=validation_dataset, epochs=EPOCHS,
          callbacks=[EarlyS, CustomCallback(gru_layer)], verbose=1)
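To fold in question (2)'s random noise on top of the learned state, one option is to perturb w at each reset. An untested sketch (NoisyResetCallback and noise_stddev are names introduced here, mirroring the reset_states call above):

class NoisyResetCallback(tf.keras.callbacks.Callback):
    def __init__(self, gru_layer, noise_stddev=0.05):
        super().__init__()
        self.gru_layer = gru_layer
        self.noise_stddev = noise_stddev

    def on_epoch_end(self, epoch, logs=None):
        # Add fresh Gaussian noise to the trainable initial state before
        # assigning it as the stateful GRU's new state.
        noise = tf.random.normal(shape=self.gru_layer.w.shape,
                                 stddev=self.noise_stddev)
        self.gru_layer.gru.reset_states(self.gru_layer.w + noise)

The noise only changes where each epoch restarts; w itself is still updated by gradients through the initial_state path in the model's call().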

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1: kiriloff