Keras - no good way to stop and resume training?

After a lot of research, it seems there is no good way to properly stop and resume training with a TensorFlow 2 / Keras model. This is true whether you use model.fit() or a custom training loop.

There seem to be 2 supported ways to save a model while training:

  1. Save just the weights of the model, using model.save_weights() or save_weights_only=True with tf.keras.callbacks.ModelCheckpoint. This seems to be preferred by most of the examples I've seen; however, it has a number of major issues:

    • The optimizer state is not saved, meaning training resumption will not be correct.
    • Learning rate schedule is reset - this can be catastrophic for some models.
    • TensorBoard logs go back to step 0 - making logging essentially useless unless complex workarounds are implemented.
  2. Save the entire model, optimizer, etc. using model.save() or save_weights_only=False. The optimizer state is saved (good) but the following issues remain:

    • TensorBoard logs still go back to step 0
    • Learning rate schedule is still reset (!!!)
    • It is impossible to use custom metrics.
    • This doesn't work at all when using a custom training loop - custom training loops use a non-compiled model, and saving/loading a non-compiled model doesn't seem to be supported.

The best workaround I've found is to use a custom training loop and manually save the step. This fixes the TensorBoard logging, and the learning rate schedule can be fixed by doing something like keras.backend.set_value(model.optimizer.iterations, step). However, since a full model save is off the table, the optimizer state is not preserved. I can see no way to save the state of the optimizer independently, at least not without a lot of work. And messing with the LR schedule as I've done feels messy as well.

Am I missing something? How are people out there saving/resuming using this API?



Solution 1:[1]

You're right, there isn't built-in support for resumability - which is exactly what motivated me to create DeepTrain. It's like PyTorch Lightning (better and worse in different regards) for TensorFlow/Keras.

Why another library? Don't we have enough? There is nothing else like this; if there were, I wouldn't have built it. DeepTrain is tailored for the "babysitting approach" to training: train fewer models, but train them thoroughly. Closely monitor each stage to diagnose what's wrong and how to fix it.

Inspiration came from my own use; I'd see "validation spikes" throughout a long epoch, and couldn't afford to pause as it'd restart the epoch or otherwise disrupt the train loop. And forget knowing which batch you were fitting, or how many remain.

How does it compare to PyTorch Lightning? Superior resumability and introspection, along with unique train debug utilities - but Lightning fares better in other regards. I have a comprehensive comparison list in the works, and will post it within a week.

PyTorch support coming? Maybe. If I convince the Lightning dev team to make up for its shortcomings relative to DeepTrain, then no - otherwise probably. In the meantime, you can explore the gallery of Examples.


Minimal example:

from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from deeptrain import TrainGenerator, DataGenerator

ipt = Input((16,))
out = Dense(10, 'softmax')(ipt)
model = Model(ipt, out)
model.compile('adam', 'categorical_crossentropy')

dg  = DataGenerator(data_path="data/train", labels_path="data/train/labels.npy")
vdg = DataGenerator(data_path="data/val",   labels_path="data/val/labels.npy")
tg  = TrainGenerator(model, dg, vdg, epochs=3, logs_dir="logs/")

tg.train()

You can KeyboardInterrupt at any time, inspect the model, train state, data generator - and resume.

Solution 2:[2]

The tf.keras.callbacks.experimental.BackupAndRestore API for resuming training after interruptions was added in tensorflow>=2.3. It works great in my experience.

Reference: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/experimental/BackupAndRestore

Solution 3:[3]

tf.keras.callbacks.BackupAndRestore can take care of this.

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution    Source
Solution 1  OverLordGoldDragon
Solution 2  yanp
Solution 3  mehran