Making GRU/LSTM states trainable in TensorFlow/Keras and adding some random noise
I train the following GRU-based model. Note that I am passing the argument stateful=True
to the GRU constructor.
class LearningToSurpriseModel(tf.keras.Model):
    """Embedding -> stateful GRU -> Dense model that outputs one score per vocabulary entry.

    Args:
        vocab_size: Size of the embedding input and of the Dense output.
        embedding_dim: Output dimension of the embedding layer.
        rnn_units: Number of units in the GRU layer.
    """

    def __init__(self, vocab_size, embedding_dim, rnn_units):
        # No positional arguments: passing `self` again hands the base
        # initialiser an extra positional argument, which newer Keras
        # versions reject with a TypeError.
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(
            rnn_units,
            stateful=True,
            return_sequences=True,
            return_state=True,
            reset_after=True,
        )
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, states=None, return_state=False, training=False):
        """Run the forward pass.

        Args:
            inputs: Input passed to the embedding layer.
            states: Optional initial state for the GRU; when None, the GRU's
                own ``get_initial_state`` result is used.
            return_state: When True, also return the final GRU state.
            training: Forwarded to every sub-layer.

        Returns:
            The Dense output, or ``(output, states)`` when ``return_state`` is True.
        """
        x = inputs
        x = self.embedding(x, training=training)
        if states is None:
            states = self.gru.get_initial_state(x)
        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x, training=training)
        if return_state:
            return x, states
        return x

    @tf.function
    def train_step(self, inputs):
        # [defining here my training step]
        # NOTE: the body was elided in the original post; supply the real
        # training step here.
        ...
I instantiate my model
# The vocabulary size is read from the `ids_from_chars` lookup.
model = LearningToSurpriseModel(
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units,
)
[compile and do stuff] The custom callback below resets the states manually at the end of each epoch.
# Grab the recurrent layer so the callback can reset its state.
# Index 1 presumably selects the GRU because the model creates its layers in
# the order embedding, gru, dense — TODO confirm against model.layers.
gru_layer = model.layers[1]
class CustomCallback(tf.keras.callbacks.Callback):
    """Keras callback that resets a recurrent layer's state after every epoch.

    Args:
        gru_layer: Layer whose ``reset_states()`` is called at each epoch end.
    """

    def __init__(self, gru_layer):
        # Run the base Callback initialiser first so the attributes it sets
        # up exist before Keras uses this callback.
        super().__init__()
        self.gru_layer = gru_layer

    def on_epoch_end(self, epoch, logs=None):
        """Reset the wrapped layer's state; ``epoch`` and ``logs`` are not used."""
        self.gru_layer.reset_states()
# Train with early stopping plus the per-epoch state reset.
model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=EPOCHS,
    callbacks=[EarlyS, CustomCallback(gru_layer)],
    verbose=1,
)
States will be reset to zero. I would like to follow the ideas in https://r2rt.com/non-zero-initial-states-for-recurrent-neural-networks.html to make the states trainable. The implementation in that post seems to be based on TensorFlow and overrides native functions; maybe there is a more elegant way in Keras.
(1) How do I make the states trainable?
(2) How do I combine trainable states and random initialization?
Solution 1:[1]
Based on the very good answer above, here is the full code that solves the problem:
class CustomGRULayer(tf.keras.layers.Layer):
    """Stateful GRU wrapper whose initial state is a trainable weight.

    Args:
        rnn_units: Number of units in the wrapped GRU.
        batch_size: Number of rows of the trainable state ``w``; the state has
            shape ``(batch_size, rnn_units)``.
    """

    def __init__(self, rnn_units, batch_size):
        super().__init__()
        self.rnn_units = rnn_units
        self.batch_size = batch_size
        self.gru = tf.keras.layers.GRU(
            self.rnn_units,
            stateful=True,
            return_sequences=True,
            return_state=True,
            reset_after=True,
        )
        # Created in build(); stays None until then.
        self.w = None

    def build(self, input_shape):
        """Create the trainable state ``w``, drawn from N(0, 0.2**2)."""
        w_init = tf.random_normal_initializer(mean=0.0, stddev=0.2)
        # add_weight registers the variable as a weight of this layer, so it
        # shows up in trainable_weights regardless of how the Keras version
        # tracks raw tf.Variable attributes.
        self.w = self.add_weight(
            name="initial_state",
            shape=(self.batch_size, self.rnn_units),
            initializer=w_init,
            dtype="float32",
            trainable=True,
        )

    def call(self, inputs, training=None):
        """Run the GRU with ``w`` as its initial state.

        Args:
            inputs: Input passed to the wrapped GRU.
            training: Forwarded to the wrapped GRU.

        Returns:
            Whatever the wrapped GRU returns; it is configured with
            ``return_sequences=True`` and ``return_state=True``.
        """
        return self.gru(inputs, initial_state=self.w, training=training)
class LearningToSurpriseModel(tf.keras.Model):
    """Embedding -> CustomGRULayer -> Dense model with a trainable initial GRU state.

    Args:
        vocab_size: Size of the embedding input and of the Dense output.
        embedding_dim: Output dimension of the embedding layer.
        rnn_units: Number of units in the GRU.
        batch_size: Forwarded to CustomGRULayer to size its trainable state.
    """

    def __init__(self, vocab_size, embedding_dim, rnn_units, batch_size):
        # No positional arguments: passing `self` again hands the base
        # initialiser an extra positional argument, which newer Keras
        # versions reject with a TypeError.
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru_layer = CustomGRULayer(rnn_units=rnn_units, batch_size=batch_size)
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, states=None, return_state=False, training=False):
        """Run the forward pass.

        Args:
            inputs: Input passed to the embedding layer.
            states: Optional explicit initial state for the GRU. When None,
                the CustomGRULayer is called so its trainable state is used.
            return_state: When True, also return the final GRU state.
            training: Forwarded to the sub-layers that receive it below.

        Returns:
            The Dense output, or ``(output, states)`` when ``return_state`` is True.
        """
        x = self.embedding(inputs, training=training)
        if states is None:
            # Call the wrapper itself, not its inner GRU: this is what builds
            # the wrapper and supplies its trainable `w` as the initial state.
            # Going straight to `self.gru_layer.gru` with zeros would leave
            # `w` unbuilt and unused. `training` is not forwarded here; the
            # GRU is configured without dropout.
            x, states = self.gru_layer(x)
        else:
            # Caller supplied a state: bypass `w` and feed it to the GRU.
            x, states = self.gru_layer.gru(
                x, initial_state=states, training=training
            )
        x = self.dense(x, training=training)
        if return_state:
            return x, states
        return x
# The vocabulary size is read from the `ids_from_chars` lookup; BATCH_SIZE
# sizes the trainable GRU state inside the model.
model = LearningToSurpriseModel(
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units,
    batch_size=BATCH_SIZE,
)
# The loss is configured with from_logits=True, i.e. it expects raw scores.
model.compile(
    optimizer='adam',
    loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
# Early stopping on val_loss with a patience of 10, restoring the best weights.
EarlyS = EarlyStopping(
    monitor='val_loss',
    mode='min',
    restore_best_weights=True,
    patience=10,
    verbose=1,
)
# Define a custom callback that resets the recurrent state only at the end of
# each epoch. Index 1 presumably selects the CustomGRULayer wrapper because the
# model creates its layers in the order embedding, gru_layer, dense — TODO confirm.
gru_layer = model.layers[1]
class CustomCallback(tf.keras.callbacks.Callback):
    """Keras callback that resets the wrapped GRU's state after every epoch.

    Args:
        gru_layer: Wrapper layer exposing the inner GRU as ``.gru`` and the
            state values to reset to as ``.w``.
    """

    def __init__(self, gru_layer):
        # Run the base Callback initialiser first so the attributes it sets
        # up exist before Keras uses this callback.
        super().__init__()
        self.gru_layer = gru_layer

    def on_epoch_end(self, epoch, logs=None):
        """Reset the inner GRU's state, passing ``w`` as the new state values.

        ``epoch`` and ``logs`` are not used.
        """
        self.gru_layer.gru.reset_states(self.gru_layer.w)
# Train with early stopping plus the per-epoch state reset.
model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=EPOCHS,
    callbacks=[EarlyS, CustomCallback(gru_layer)],
    verbose=1,
)
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | kiriloff |