How can I use tf.train.Checkpoint to recover a previously trained net?
I'm trying to understand how to recover a saved/checkpointed net using tensorflow.train.Checkpoint.restore.
I'm using code that's strongly based on Google's Colab tutorial for creating a pix2pix GAN. Below, I've excerpted the key portion, which just attempts to instantiate a new net, then to fill it with weights from a previous net that was saved and checkpointed.
I'm assigning a unique(ish) id number to a particular instantiation of a net by summing all the weights of the net. I compare these id numbers both when the net is created and after I've attempted to recover the checkpointed net.
def main(opt):
    # Initialize pix2pix GAN using arguments input from command line
    p2p = Pix2Pix(vars(opt))
    print(opt)
    # print sum of initial weights for net
    print("Init Model Weights:",
          sum([x.numpy().sum() for x in p2p.generator.weights]))
    # Create or read from model checkpoints
    checkpoint = tf.train.Checkpoint(generator_optimizer=p2p.generator_optimizer,
                                     discriminator_optimizer=p2p.discriminator_optimizer,
                                     generator=p2p.generator,
                                     discriminator=p2p.discriminator)
    # print sum of weights from checkpoint, to ensure it has access
    # to relevant regions of p2p
    print("Checkpoint Weights:",
          sum([x.numpy().sum() for x in checkpoint.generator.weights]))
    # Recover Checkpointed net
    checkpoint.restore(tf.train.latest_checkpoint(opt.weights)).expect_partial()
    # print sum of weights for p2p & checkpoint after attempting to restore saved net
    print("Restore Model Weights:",
          sum([x.numpy().sum() for x in p2p.generator.weights]))
    print("Restored Checkpoint Weights:",
          sum([x.numpy().sum() for x in checkpoint.generator.weights]))
    print("Done.")

if __name__ == '__main__':
    opt = parse_opt()
    main(opt)
The output I got when I ran this code was as follows:
Namespace(channels='1', data='data', img_size=256, output='output', weights='weights/ckpt-40.data-00000-of-00001')
## These are the input arguments, the images have only 1 channel (they're gray scale)
## The directory with data is ./data, the images are 256x256
## The output directory is ./output
## The checkpointed net is stored in ./weights/ckpt-40.data-00000-of-00001
## Sums of nets' weights
Init Model Weights: 11047.206374436617
Checkpoint Weights: 11047.206374436617
Restore Model Weights: 11047.206374436617
Restored Checkpoint Weights: 11047.206374436617
Done.
There is no change in the sum of the net's weights before and after recovering the checkpointed version, although p2p and checkpoint do seem to have access to the same locations in memory.
Why am I not recovering the saved net?
Solution 1:[1]
The problem arose because tf.train.latest_checkpoint, whose result I pass to tf.train.Checkpoint.restore, needs the directory in which the checkpointed net is stored, not the specific file (or what I took to be the specific file, ./weights/ckpt-40.data-00000-of-00001).
When latest_checkpoint is not given a valid checkpoint directory, it returns None, and restore(None) silently proceeds to the next line of code without updating the net or throwing an error. The fix was to give it the directory with the relevant checkpoint files (./weights in my case), rather than just the file I believed to be relevant.
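For anyone who wants the script to fail loudly instead, here is a minimal sketch of two small guards. The helper names to_checkpoint_prefix and restore_or_raise are hypothetical, not part of TensorFlow. The first strips the '.index' or '.data-XXXXX-of-XXXXX' suffix from a file name so that only the checkpoint prefix remains (the form that restore accepts directly), and the second refuses a None path rather than letting restore do nothing.

```python
import re

# Matches the per-file suffixes TensorFlow appends to a checkpoint prefix,
# e.g. "ckpt-40.index" or "ckpt-40.data-00000-of-00001".
_CHECKPOINT_SUFFIX = re.compile(r"\.(?:index|data-\d+-of-\d+)$")


def to_checkpoint_prefix(path):
    """Return `path` without a trailing checkpoint file suffix (hypothetical helper)."""
    return _CHECKPOINT_SUFFIX.sub("", path)


def restore_or_raise(checkpoint, save_path):
    """Call checkpoint.restore(save_path), but raise if save_path is None.

    `checkpoint` is assumed to be a tf.train.Checkpoint (anything with a
    `restore` method works); a None path is what tf.train.latest_checkpoint
    returns when it finds no checkpoint in the directory it was given.
    """
    if save_path is None:
        raise FileNotFoundError(
            "No checkpoint found; pass a checkpoint directory to "
            "tf.train.latest_checkpoint or a checkpoint prefix to restore."
        )
    return checkpoint.restore(save_path)


# Hypothetical usage with the objects from the question (requires TensorFlow):
#   restore_or_raise(checkpoint, tf.train.latest_checkpoint("weights")).expect_partial()
#   restore_or_raise(checkpoint, to_checkpoint_prefix(opt.weights)).expect_partial()
```

As an extra check, replacing expect_partial() with assert_existing_objects_matched() on the returned status should raise if objects in the program were not matched by the checkpoint, which may be stricter than a GAN restore for inference needs.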
Solution 2:[2]
An alternative is to save the model through a Keras ModelCheckpoint callback during training and then restore it. You may also name the layers so that the checkpoint can identify them.
Example:
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: DataSet
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
DATA = adding_array_DATA(DATA, action, reward, gamescores, step)
dataset = tf.data.Dataset.from_tensor_slices((tf.constant(DATA, dtype=tf.float32),tf.constant(np.reshape(0, (1, 1, 1, 1)))))
batched_features = dataset
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Model Initialize
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(1200, 1)),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128, return_sequences=True, return_state=False)),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128)),
])
# Use the fully qualified tf.keras.layers path so no separate `layers` import is needed
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(64))
model.add(tf.keras.layers.Dense(2))
model.summary()
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Callback
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_dir, monitor='val_loss',
                                                 verbose=0, save_best_only=True, mode='min')
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Optimizer
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
optimizer = tf.keras.optimizers.Nadam(
    learning_rate=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-07,
    name='Nadam'
) # 0.00001
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Loss Fn
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
# 1
lossfn = tf.keras.losses.MeanSquaredLogarithmicError(reduction=tf.keras.losses.Reduction.AUTO, name='mean_squared_logarithmic_error')
# 2
# lossfn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Model Summary
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
model.compile(optimizer=optimizer, loss=lossfn, metrics=['accuracy'])
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Training
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
history = model.fit(batched_features, epochs=1 ,validation_data=(batched_features), callbacks=[cp_callback]) # epochs=500 # , callbacks=[cp_callback, tb_callback]
checkpoint = tf.train.Checkpoint(model)
checkpoint.restore(checkpoint_dir)
input('...')
Output:
2022-03-08 10:33:06.965274: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8100
1/1 [==============================] - ETA: 0s - **loss: 0.0154** - accuracy: 0.0000e+002022-03-08 10:33:16.175845: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
1/1 [==============================] - 31s 31s/step - **loss: 0.0154** - accuracy: 0.0000e+00 - val_loss: 0.0074 - val_accuracy: 0.0000e+00
...
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | user1245262 |
| Solution 2 | Martijn Pieters |