How to apply ModelCheckpoint in a custom training loop for TensorFlow 2.0?

We can create a tf.keras.callbacks.ModelCheckpoint() and pass it via the callbacks argument of the fit() method to save the best model checkpoint — but how do we achieve the same thing in a custom training loop?



Solution 1:[1]

  1. You can use CallbackList:

     # Build the ModelCheckpoint callback exactly as you would for fit():
     # with save_freq='epoch' it reads `val_loss` from the logs dict that we
     # maintain manually and pass to on_epoch_end below.
     cp_callback = ModelCheckpoint(filepath=checkpoint_path,
                                   monitor='val_loss',
                                   save_weights_only=True,
                                   save_best_only=True,
                                   mode='auto',
                                   save_freq='epoch',
                                   verbose=1)
     # Wrap the callback(s) in a CallbackList so the usual lifecycle hooks
     # (on_train_begin, on_epoch_end, ...) can be dispatched manually.
     # BUG FIX: the original passed an undefined name `_callbacks` here.
     callbacks = CallbackList([cp_callback], add_history=True, model=model)
     logs = {}
     callbacks.on_train_begin(logs=logs)
    
     optimizer = Adam(lr=init_lr, beta_1=0.9, beta_2=0.999, clipvalue=1.0)
    
     loss_fn = BinaryCrossentropy(from_logits=False)
     # Running loss/accuracy trackers; reset at the end of every epoch.
     train_loss_tracker = tf.keras.metrics.Mean()
     val_loss_tracker = tf.keras.metrics.Mean()
     train_acc_metric = tf.keras.metrics.BinaryAccuracy()
     val_acc_metric = tf.keras.metrics.BinaryAccuracy()
    
     @tf.function(experimental_relax_shapes=True)
     def train_step(x, y):
         """One optimization step; returns the running epoch loss/accuracy."""
         with tf.GradientTape() as tape:
             logits = self.net(x, training=True)
             loss_value = loss_fn(y, logits)
         grads = tape.gradient(loss_value, self.net.trainable_weights)
         optimizer.apply_gradients(zip(grads, self.net.trainable_weights))
         train_loss_tracker.update_state(loss_value)
         train_acc_metric.update_state(y, logits)
         return {"train_loss": train_loss_tracker.result(), "train_accuracy": train_acc_metric.result()}
    
     @tf.function(experimental_relax_shapes=True)
     def val_step(x, y):
         """Forward pass only; accumulates validation loss/accuracy."""
         val_logits = self.net(x, training=False)
         val_loss = loss_fn(y, val_logits)
         val_loss_tracker.update_state(val_loss)
         val_acc_metric.update_state(y, val_logits)
         return {"val_loss": val_loss_tracker.result(), "val_accuracy": val_acc_metric.result()}
    
     for epoch in range(args.max_epoch):
    
         print("\nStart of epoch %d" % (epoch,))
         start_time = time.time()
         # Fire the epoch-begin hook so stateful callbacks see every phase;
         # the original only fired on_epoch_end.
         callbacks.on_epoch_begin(epoch, logs=logs)
    
         for step, (x_batch_train, y_batch_train) in enumerate(train_gen):
    
             callbacks.on_batch_begin(step, logs=logs)
             callbacks.on_train_batch_begin(step, logs=logs)
             train_dict = train_step(x_batch_train, np.expand_dims(y_batch_train, axis=0))
    
             # Store a plain float (not a tensor) so callbacks can compare
             # and print the value safely.
             logs["train_loss"] = float(train_dict["train_loss"])
    
             callbacks.on_train_batch_end(step, logs=logs)
             callbacks.on_batch_end(step, logs=logs)
             if step % 100 == 0:
                 print("Training loss (for one batch) at step %d: %.4f" % (step, logs["train_loss"]))
    
         train_acc = train_acc_metric.result()
         print("Training acc over epoch: %.4f" % (float(train_acc),))
         train_acc_metric.reset_states()
         train_loss_tracker.reset_states()
    
         for step, (x_batch_val, y_batch_val) in enumerate(val_gen):
             callbacks.on_batch_begin(step, logs=logs)
             callbacks.on_test_batch_begin(step, logs=logs)
    
             val_step(x_batch_val, np.expand_dims(y_batch_val, axis=0))
    
             callbacks.on_test_batch_end(step, logs=logs)
             callbacks.on_batch_end(step, logs=logs)
    
         # ModelCheckpoint(save_best_only=True) reads `val_loss` from logs in
         # on_epoch_end, so it must be set BEFORE that hook fires.
         logs["val_loss"] = float(val_loss_tracker.result())
         val_acc = val_acc_metric.result()
         print("Validation acc: %.4f" % (float(val_acc),))
         # BUG FIX: the original line was missing its closing parenthesis.
         print("Time taken: %.2fs" % (time.time() - start_time))
    
         val_acc_metric.reset_states()
         val_loss_tracker.reset_states()
         callbacks.on_epoch_end(epoch, logs=logs)
    
     callbacks.on_train_end(logs=logs)
    

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 user3606057