How can I use tf.train.Checkpoint to recover a previously trained net?

I'm trying to understand how to recover a saved/checkpointed net using tf.train.Checkpoint.restore.

I'm using code that's strongly based on Google's Colab tutorial for creating a pix2pix GAN. Below, I've excerpted the key portion, which simply instantiates a new net and then attempts to fill it with weights from a previously saved and checkpointed net.

I'm assigning a unique(ish) ID number to a particular instantiation of a net by summing all of its weights. I compare these ID numbers both at the creation of the net and after attempting to recover the checkpointed net.

import tensorflow as tf
# Pix2Pix and parse_opt are defined elsewhere in the (excerpted) script

def main(opt):

    # Initialize pix2pix GAN using arguments input from command line
    p2p = Pix2Pix(vars(opt))
    print(opt)

    # print sum of initial weights for net
    print("Init Model Weights:", 
           sum([x.numpy().sum() for x in p2p.generator.weights]))

    # Create or read from model checkpoints
    checkpoint = tf.train.Checkpoint(generator_optimizer=p2p.generator_optimizer,
                                     discriminator_optimizer=p2p.discriminator_optimizer,
                                     generator=p2p.generator,
                                     discriminator=p2p.discriminator)
    
    # print sum of weights from checkpoint, to ensure it has access 
    # to relevant regions of p2p
    print("Checkpoint Weights:", 
           sum([x.numpy().sum() for x in checkpoint.generator.weights]))

    # Recover Checkpointed net
    checkpoint.restore(tf.train.latest_checkpoint(opt.weights)).expect_partial()

    # print sum of weights for p2p & checkpoint after attempting to restore saved net 
    print("Restore Model Weights:", 
           sum([x.numpy().sum() for x in p2p.generator.weights]))
    print("Restored Checkpoint Weights:", 
           sum([x.numpy().sum() for x in checkpoint.generator.weights]))
    print("Done.")

if __name__ == '__main__':
    opt = parse_opt()
    main(opt)

The output I got when I ran this code was as follows:

Namespace(channels='1', data='data', img_size=256, output='output', weights='weights/ckpt-40.data-00000-of-00001')
## These are the input arguments; the images have only 1 channel (they're grayscale)
## The directory with data is ./data, the images are 256x256
## The output directory is ./output
## The checkpointed net is stored in ./weights/ckpt-40.data-00000-of-00001


## Sums of nets' weights
Init Model Weights: 11047.206374436617
Checkpoint Weights: 11047.206374436617
Restore Model Weights: 11047.206374436617
Restored Checkpoint Weights: 11047.206374436617

Done.

There is no change in the sum of the net's weights before and after recovering the checkpointed version, although p2p and checkpoint do seem to have access to the same locations in memory.

Why am I not recovering the saved net?



Solution 1:[1]

The problem arose because tf.train.latest_checkpoint needs the directory in which the checkpointed net is stored, not a specific file (or what I took to be the specific file, ./weights/ckpt-40.data-00000-of-00001).

When it is not given a valid directory, it returns None, and checkpoint.restore(None) silently does nothing: execution proceeds to the next line without updating the net or raising an error. The fix was to pass the directory containing the checkpoint files rather than the path of one of the data shard files.
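A minimal sketch of the corrected restore step (assuming the checkpoint files, e.g. ckpt-40.index and ckpt-40.data-00000-of-00001, live in ./weights, and using the same checkpoint object as in the question):

# latest_checkpoint expects the directory that holds the checkpoint files and
# returns the newest checkpoint prefix (e.g. "weights/ckpt-40"), or None if
# the directory contains no checkpoints.
ckpt_prefix = tf.train.latest_checkpoint("weights")
if ckpt_prefix is None:
    raise FileNotFoundError("No checkpoint found in ./weights")

# Restoring from the prefix actually loads the saved values into the objects
# tracked by the checkpoint (and therefore into p2p.generator and friends).
checkpoint.restore(ckpt_prefix).expect_partial()

With a valid prefix, the "Restore Model Weights" sum printed afterwards should differ from the freshly initialized one.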

Solution 2:[2]

My alternative is to use a ModelCheckpoint callback during training and then restore from the checkpoint it writes; you can also name layers so they are easy to identify in the checkpoint.

Example:

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: DataSet
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
DATA = adding_array_DATA(DATA, action, reward, gamescores, step)

dataset = tf.data.Dataset.from_tensor_slices((tf.constant(DATA, dtype=tf.float32),tf.constant(np.reshape(0, (1, 1, 1, 1)))))
batched_features = dataset

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Model Initialize
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(1200, 1)),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128, return_sequences=True, return_state=False)),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128)),
])
        
model.add(layers.Flatten())
model.add(layers.Dense(64))
model.add(layers.Dense(2))
model.summary()

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Callback
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
# Save the best model (lowest val_loss) seen during training to checkpoint_dir
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_dir, monitor='val_loss', 
                                verbose=0, save_best_only=True, mode='min' )
                                
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Optimizer
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
optimizer = tf.keras.optimizers.Nadam(
    learning_rate=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-07,
    name='Nadam'
) # 0.00001

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Loss Fn
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""                               
# 1
lossfn = tf.keras.losses.MeanSquaredLogarithmicError(reduction=tf.keras.losses.Reduction.AUTO, name='mean_squared_logarithmic_error')
# 2
# lossfn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Model Summary
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
model.compile(optimizer=optimizer, loss=lossfn, metrics=['accuracy'])

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Training
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
history = model.fit(batched_features, epochs=1, validation_data=batched_features, callbacks=[cp_callback]) # epochs=500 # , callbacks=[cp_callback, tb_callback]

# Wrap the model in a tf.train.Checkpoint and attempt to restore it.
# As in Solution 1, restore expects a checkpoint prefix rather than a bare
# directory, so this only works if checkpoint_dir resolves to one.
checkpoint = tf.train.Checkpoint(model)
checkpoint.restore(checkpoint_dir)

input('...')

Output:

2022-03-08 10:33:06.965274: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8100
1/1 [==============================] - ETA: 0s - loss: 0.0154 - accuracy: 0.0000e+002022-03-08 10:33:16.175845: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
1/1 [==============================] - 31s 31s/step - loss: 0.0154 - accuracy: 0.0000e+00 - val_loss: 0.0074 - val_accuracy: 0.0000e+00
...
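
A note on the restore step above: because ModelCheckpoint is used with its default save_weights_only=False, it writes a full Keras model to checkpoint_dir, so an alternative way to bring the trained model back is tf.keras.models.load_model. A minimal sketch, assuming checkpoint_dir is the same path that was passed to the callback:

# Rebuild the architecture and load the best weights saved by the callback.
restored_model = tf.keras.models.load_model(checkpoint_dir)
restored_model.summary()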


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution sources:
Solution 1: user1245262
Solution 2: Martijn Pieters