How to ignore regularization losses from a sub-layer of tf.keras.Model?

I want to distill knowledge from a teacher network to a student one, so I implemented a class named OCCSE that inherits from tf.keras.Model and takes both the teacher and the student as sub-layers (code below, slightly simplified for the example).

It works for teacher networks without any regularization. For teachers that do use regularization, I noticed that the teacher's regularization losses are added to the wrapping model's list of losses (self.losses), even though I set teacher.trainable = False.

How can I prevent this from happening? I don't want the teacher's regularization in the optimization process: it is not relevant to training the student, and, being a rank-0 tensor, it cannot be added directly to the non-reduced loss (I train on multiple GPUs, so I use reduction='none' in BinaryCrossentropy, which makes the loss a tensor of shape [batch_size]).
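To illustrate what I mean, here is a minimal sketch, separate from the OCCSE code below (the Wrapper class and the L2 penalty are just placeholders): freezing the sub-model does not remove its weight-regularization terms from the wrapper's .losses.

import tensorflow as tf

teacher = tf.keras.Sequential([
  tf.keras.layers.Dense(4, kernel_regularizer=tf.keras.regularizers.l2(1e-4),
                        input_shape=(8,))
])
teacher.trainable = False  # frozen, like my real teacher


class Wrapper(tf.keras.Model):
  def __init__(self, teacher, **kwargs):
    super().__init__(**kwargs)
    self.teacher = teacher

  def call(self, x):
    return self.teacher(x)


wrapper = Wrapper(teacher)
_ = wrapper(tf.zeros([2, 8]))
print(wrapper.losses)  # still contains the teacher's L2 penalty, a rank-0 tensor
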

import tensorflow as tf
from typing import Optional

from tensorflow.keras import Model
from tensorflow.python.keras.engine import data_adapter  # for unpack_x_y_sample_weight


class OCCSE(Model):
  def __init__(
    self,
    inference_net: tf.keras.Model,           # the student being trained
    support_net: Optional[tf.keras.Model],   # the teacher (frozen, trainable = False)
    ...
  ):
    super().__init__(**kwargs)

    self.inference_net = inference_net
    self.support_net = support_net
    ...

  def train_step(self, data):
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
    masks = ...
    label_mask = ...

    with tf.GradientTape() as tape:
      # assuming the student returns (predictions, activation maps) when with_maps=True
      p, maps = self.inference_net(x, training=True, with_maps=True)
      x_erased = x * (1 - masks)
      ps = self.support_net(x_erased, training=False)

      loss_support = self.loss_support(ps, y - label_mask)

      loss = self.compiled_loss(y, p, sample_weight, regularization_losses=self.losses)
      loss = loss + alpha * loss_support
      loss = tf.nn.compute_average_loss(loss, global_batch_size=self.batch_size)

    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))

    return self.compute_metrics(x, y, p, sample_weight)
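
One direction I am considering (a sketch only, I am not sure it is the intended way) is to pass just the student's regularization terms to compiled_loss, since self.inference_net.losses should not include anything registered by self.support_net:

      # in train_step, instead of regularization_losses=self.losses:
      loss = self.compiled_loss(
          y, p, sample_weight,
          regularization_losses=self.inference_net.losses)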


Sources

Source: Stack Overflow, licensed under CC BY-SA 3.0.
