How to ignore regularization losses from a sub-layer of tf.keras.Model?
I want to distill knowledge from a teacher network to a student one, so I implemented a class named OCCSE
that inherits from tf.keras.Model
and accepts both a teacher and a student as sub-layers (code below, slightly simplified for the example).
It works for teacher networks without any regularization.
For other networks, I noticed that the regularization losses from the teacher are added to the wrapping model's list of losses, even though I set teacher.trainable = False.
How can I prevent this from happening? I don't want the teacher's regularization in the optimization process: it is irrelevant to training the student, and it is a rank-0 tensor that cannot simply be added to the non-reduced loss
(I train on multiple GPUs, so I use reduction='none'
in BinaryCrossentropy,
which makes the loss a tensor of shape [batch_size]).
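For concreteness, here is a minimal, self-contained sketch (toy tensors and values, independent of the class below) of the shapes involved: the per-example loss goes through tf.nn.compute_average_loss, while a rank-0 regularization term would need tf.nn.scale_regularization_loss, so the two terms are scaled differently under tf.distribute.

import tensorflow as tf

# Toy, hypothetical values; only meant to show the shapes involved.
bce = tf.keras.losses.BinaryCrossentropy(
    reduction=tf.keras.losses.Reduction.NONE)

y_true = tf.constant([[1.0], [0.0], [1.0], [0.0]])
y_pred = tf.constant([[0.9], [0.2], [0.7], [0.4]])

per_example = bce(y_true, y_pred)   # shape [batch_size] == [4]
reg = tf.constant(0.01)             # rank 0, like a summed model.losses

# Under tf.distribute the two terms are scaled differently: the per-example
# loss is averaged over the global batch, while the scalar regularization
# term is divided by the number of replicas.
loss = tf.nn.compute_average_loss(per_example, global_batch_size=4)
loss += tf.nn.scale_regularization_loss(reg)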
from typing import Optional

import tensorflow as tf
from tensorflow.keras import Model
from keras.engine import data_adapter  # or tensorflow.python.keras.engine, depending on the TF version


class OCCSE(Model):
    def __init__(
        self,
        inference_net: tf.keras.Model,            # student
        support_net: Optional[tf.keras.Model],    # teacher, kept frozen
        **kwargs,                                 # other arguments elided
    ):
        super().__init__(**kwargs)
        self.inference_net = inference_net
        self.support_net = support_net
        # ... other attributes (e.g. loss_support, batch_size) elided

    def train_step(self, data):
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
        masks = ...        # construction elided
        label_mask = ...   # construction elided

        with tf.GradientTape() as tape:
            outputs = self.inference_net(x, training=True, with_maps=True)
            p = outputs    # student predictions (unpacking elided in the question)
            x_erased = x * (1 - masks)
            ps = self.support_net(x_erased, training=False)

            loss_support = self.loss_support(ps, y - label_mask)

            # self.losses also contains the teacher's regularization terms.
            loss = self.compiled_loss(
                y, p, sample_weight, regularization_losses=self.losses)
            loss = loss + alpha * loss_support   # alpha defined elsewhere (elided)
            loss = tf.nn.compute_average_loss(
                loss, global_batch_size=self.batch_size)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        return self.compute_metrics(x, y, p, sample_weight)
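One direction that might avoid pulling in the teacher's terms (a sketch only, assuming inference_net is the student and support_net the frozen teacher) is to read the regularization losses from the student sub-model instead of from the wrapper, since Model.losses on the wrapper aggregates over all tracked sub-layers. A self-contained illustration with hypothetical teacher, student and Wrapper stand-ins:

import tensorflow as tf

# Hypothetical stand-ins: a frozen "teacher" with L2 regularization and a
# trainable "student" without any.
teacher = tf.keras.Sequential([
    tf.keras.layers.Dense(
        4,
        kernel_regularizer=tf.keras.regularizers.l2(1e-2),
        input_shape=(8,)),
])
teacher.trainable = False

student = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])


class Wrapper(tf.keras.Model):
    # Hypothetical wrapper mirroring the OCCSE structure above.
    def __init__(self, student, teacher, **kwargs):
        super().__init__(**kwargs)
        self.student = student
        self.teacher = teacher

    def call(self, x):
        return self.student(x)


wrapper = Wrapper(student, teacher)
_ = wrapper(tf.zeros((2, 8)))

# The wrapper's .losses aggregates losses from every tracked sub-layer,
# teacher included; the student's .losses contains only its own terms.
print([float(l) for l in wrapper.losses])          # teacher's L2 term
print([float(l) for l in wrapper.student.losses])  # empty here

Applied to the train_step above, that would mean passing regularization_losses=self.inference_net.losses instead of self.losses, or dropping the argument and adding tf.nn.scale_regularization_loss(tf.add_n(self.inference_net.losses)) after compute_average_loss so the scalar term is scaled correctly across replicas; whether that matches the intended semantics depends on the full model.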
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow