How to apply ModelCheckpoint in a custom training loop for TensorFlow 2.0?
we can set tf.keras.callbacks.ModelCheckpoint()
, then pass a callbacks
argument to fit()
method to save the best model checkpoint — but how can we do the same thing in a custom training loop?
Solution 1:[1]
You can use CallbackList:
# Custom training loop that drives Keras callbacks by hand via CallbackList,
# so ModelCheckpoint's best-model tracking works without model.fit().
# NOTE(review): relies on names defined elsewhere (self.net, args, train_gen,
# val_gen, checkpoint_path, init_lr, model) — presumably this lives inside a
# trainer method; confirm against the enclosing class.
cp_callback = ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_loss',        # reads logs["val_loss"] set at epoch end below
    save_weights_only=True,
    save_best_only=True,
    mode='auto',
    save_freq='epoch',
    verbose=1,
)
# BUG FIX: the original passed an undefined name `_callbacks` to CallbackList
# while `cp_callback` was never registered; register the checkpoint callback.
callbacks = CallbackList([cp_callback], add_history=True, model=model)

logs = {}
callbacks.on_train_begin(logs=logs)

optimizer = Adam(lr=init_lr, beta_1=0.9, beta_2=0.999, clipvalue=1.0)
loss_fn = BinaryCrossentropy(from_logits=False)

# Running metrics; reset explicitly at the end of each epoch phase.
train_loss_tracker = tf.keras.metrics.Mean()
val_loss_tracker = tf.keras.metrics.Mean()
train_acc_metric = tf.keras.metrics.BinaryAccuracy()
val_acc_metric = tf.keras.metrics.BinaryAccuracy()


@tf.function(experimental_relax_shapes=True)
def train_step(x, y):
    """Run one optimizer step and return the running train loss/accuracy."""
    with tf.GradientTape() as tape:
        logits = self.net(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, self.net.trainable_weights)
    optimizer.apply_gradients(zip(grads, self.net.trainable_weights))
    train_loss_tracker.update_state(loss_value)
    train_acc_metric.update_state(y, logits)
    return {"train_loss": train_loss_tracker.result(),
            "train_accuracy": train_acc_metric.result()}


@tf.function(experimental_relax_shapes=True)
def val_step(x, y):
    """Forward pass only; update the running validation loss/accuracy."""
    val_logits = self.net(x, training=False)
    val_loss = loss_fn(y, val_logits)
    val_loss_tracker.update_state(val_loss)
    val_acc_metric.update_state(y, val_logits)
    return {"val_loss": val_loss_tracker.result(),
            "val_accuracy": val_acc_metric.result()}


for epoch in range(args.max_epoch):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()
    # BUG FIX: the original never invoked on_epoch_begin even though it calls
    # on_epoch_end; epoch-scoped callbacks expect the begin/end pair.
    callbacks.on_epoch_begin(epoch, logs=logs)

    # ---- training phase ----
    for step, (x_batch_train, y_batch_train) in enumerate(train_gen):
        callbacks.on_batch_begin(step, logs=logs)
        callbacks.on_train_batch_begin(step, logs=logs)
        # assumes y_batch_train needs a leading batch axis — TODO confirm
        # against the generator's output shape.
        train_dict = train_step(x_batch_train,
                                np.expand_dims(y_batch_train, axis=0))
        logs["train_loss"] = train_dict["train_loss"]
        callbacks.on_train_batch_end(step, logs=logs)
        callbacks.on_batch_end(step, logs=logs)
        if step % 100 == 0:
            print("Training loss (for one batch) at step %d: %.4f"
                  % (step, float(logs["train_loss"])))

    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))
    train_acc_metric.reset_states()
    train_loss_tracker.reset_states()

    # ---- validation phase ----
    for step, (x_batch_val, y_batch_val) in enumerate(val_gen):
        callbacks.on_batch_begin(step, logs=logs)
        callbacks.on_test_batch_begin(step, logs=logs)
        val_step(x_batch_val, np.expand_dims(y_batch_val, axis=0))
        callbacks.on_test_batch_end(step, logs=logs)
        callbacks.on_batch_end(step, logs=logs)

    # ModelCheckpoint monitors 'val_loss' from this logs dict at epoch end.
    logs["val_loss"] = val_loss_tracker.result()
    val_acc = val_acc_metric.result()
    print("Validation acc: %.4f" % (float(val_acc),))
    # BUG FIX: the original line was missing its closing parenthesis.
    print("Time taken: %.2fs" % (time.time() - start_time))
    val_acc_metric.reset_states()
    val_loss_tracker.reset_states()
    callbacks.on_epoch_end(epoch, logs=logs)

callbacks.on_train_end(logs=logs)
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | user3606057 |