Warning: "Variables were used in a Lambda layer's call" — the layer ends up not trainable, even though it is already subclassed

enter image description here

I want to create this custom ConvRNN2D layer.
Therefore I copied the code for the ConvLSTM2D layer by tensorflow.
I adapted this code until it now matches the figure above.

The problem:

  • TensorFlow outputs a warning that the kernels and bias are not tracked.
  • The model does not count any trainable variables for this layer in the parameter column.
  • The model does not list the kernels and bias in "model.trainable_variables".
  • The EXACT same problem occurs even when I copy the TensorFlow code (ConvLSTM2D) without ANY changes.

My questions:

  • How do I make the kernels and bias trackable by tensorflow so the layer becomes trainable?
  • Why does even ConvLSTM2D have the SAME problem when I copy and paste the code without touching it?
  • Why does tensorflow list the internal layer structure in model.summary()?

Expectations vs reality:

enter image description here enter image description here

Code?

Test model

# Minimal repro: a TimeDistributed Conv2D encoder feeding the custom recurrent layer.
import numpy
import tensorflow
from tensorflow.keras.layers import Conv2D, TimeDistributed

# Dummy input batch: (batch, time, height, width, channels).
data = numpy.random.rand(1, 7, 128, 128, 5)

# Spatial dims and sequence length are left dynamic; only the channel count is fixed.
inputs = tensorflow.keras.Input(shape=[None, None, None, data.shape[-1]])
features = TimeDistributed(Conv2D(32, kernel_size=(3, 3), padding='same'))(inputs)
outputs = AECRNN(32, (3, 3), padding='same', return_sequences=True)(features)

model = tensorflow.keras.Model(inputs, outputs)
model.compile()
model.summary()

My ConvRNN2D layer

  • 80% of the lines are boilerplate copied from ConvLSTM2D.
  • The only major changes are in the cell class, in the "build" and "call" functions.
from tensorflow.keras.layers import Layer

from tensorflow.python.ops import array_ops
from tensorflow.python.keras import backend
from tensorflow.python.keras import activations
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.layers.convolutional_recurrent import ConvRNN2D


class AECRNNCell(Layer):
  """Per-timestep cell of the custom ConvRNN2D layer (wrapped by AECRNN below).

  NOTE(review): this class subclasses the *public* ``tensorflow.keras.layers.Layer``,
  while the RNN driver it is used with (``ConvRNN2D``) is imported from the
  *internal* ``tensorflow.python.keras`` package.  Mixing the public and internal
  Keras implementations is a plausible cause of the reported "variables are not
  tracked" warning — verify by importing ``Layer`` from the same namespace as
  ``ConvRNN2D``.

  NOTE(review): all weights are created directly in ``__init__`` instead of in a
  ``build(input_shape)`` method, and ``self.built`` is set by hand.  Keras expects
  trainable variables to be created in ``build`` so the base layer can track them.
  """
  def __init__(self,
               filters,
               kernel_size,
               strides=(1, 1),
               padding='same',
               data_format=None,
               dilation_rate=(1, 1),
               activation='relu',
               use_bias=True,
               kernel_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               bias_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super(AECRNNCell, self).__init__(**kwargs)

    self.filters = filters

    # Two states, each declared with `filters` channels (see NOTE in call()).
    self.state_size = (self.filters, self.filters)
    self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
    # Input kernel: (kh, kw, in_channels=filters, out_channels=filters).
    # NOTE(review): this hard-codes the input channel count to `filters`; a
    # `build()` method would normally derive it from the input shape.
    self.kernel_shape = self.kernel_size + (self.filters, self.filters)

    self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
    self.padding = conv_utils.normalize_padding(padding)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2, 'dilation_rate')
    self.activation = activations.get(activation)

    self.kernel_initializer = initializers.get(kernel_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)

    self.bias_initializer = initializers.get(bias_initializer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.bias_constraint = constraints.get(bias_constraint)

    # Weights created here (not in build) — see class-level NOTE on tracking.
    self.kernel = self.add_weight(
      shape=self.kernel_shape, name='kernel',
      initializer=self.kernel_initializer,
      regularizer=self.kernel_regularizer,
      constraint=self.kernel_constraint)

    # Double-width recurrent kernel, later split in half along the output axis.
    self.recurrent_kernel = self.add_weight(
        shape=self.kernel_size + (self.filters, self.filters * 2),
        name='recurrent_kernel',
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint)

    self.use_bias = use_bias
    self.unit_forget_bias = unit_forget_bias
    if self.use_bias:
      self.bias = self.add_weight(
        shape=(self.filters,), name='bias',
        initializer=self.bias_initializer,
        regularizer=self.bias_regularizer,
        constraint=self.bias_constraint)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs, states, training=None):
    """One recurrence step; returns (output, new_states).

    NOTE(review): the first two assignments to ``result`` are dead stores — the
    input convolution and the concatenation with ``states[0]`` are computed and
    then thrown away, because both recurrent convolutions below are applied to
    ``inputs`` rather than ``result``.  Likewise the ``kernel2`` convolution is
    immediately overwritten by the ``kernel3`` one.  Confirm the intended
    dataflow against the figure before fixing.
    """

    result = self.input_conv(inputs, self.kernel, self.bias if self.use_bias else None)
    result = backend.concatenate([result, states[0]])

    # Split (kh, kw, filters, 2*filters) into two (kh, kw, filters, filters) kernels.
    kernel2, kernel3 = array_ops.split(self.recurrent_kernel, 2, axis=3)
    result = self.recurrent_conv(inputs, kernel2)
    result = self.recurrent_conv(inputs, kernel3)

    # NOTE(review): the second returned state is the raw ``inputs`` tensor,
    # but ``state_size`` declares it as having ``filters`` channels — the shapes
    # only agree when the input channel count happens to equal ``filters``.
    return result, [result, inputs]

  def input_conv(self, x, w, b=None, padding='valid'):
    """Convolve the step input with ``w`` and optionally add bias ``b``.

    NOTE(review): the ``padding`` parameter is never used — the underlying
    convolution always takes ``self.padding`` (via recurrent_conv).
    """
    conv_out = self.recurrent_conv(x, w)
    if b is not None:
      conv_out = backend.bias_add(conv_out, b, data_format=self.data_format)
    return conv_out

  def recurrent_conv(self, x, w):
    """Plain 2-D convolution with the layer's stride/padding/dilation settings."""
    return backend.conv2d(x, w,
      strides=self.strides,
      padding=self.padding,
      data_format=self.data_format,
      dilation_rate=self.dilation_rate)

class AECRNN(ConvRNN2D):
  """Convolutional recurrent layer driven by an AECRNNCell.

  Mirrors the ConvLSTM2D wrapper pattern: the constructor builds the
  per-timestep cell and hands it to the generic ConvRNN2D driver, which
  unrolls it over the time axis.
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=(1, 1),
               padding='same',
               data_format=None,
               dilation_rate=(1, 1),
               activation='relu',
               use_bias=True,
               kernel_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               **kwargs):
    # Build the cell first; it receives every convolution/initializer option,
    # while the RNN-level options stay with the ConvRNN2D base.
    cell = AECRNNCell(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilation_rate=dilation_rate,
        activation=activation,
        use_bias=use_bias,
        kernel_initializer=kernel_initializer,
        bias_initializer=bias_initializer,
        unit_forget_bias=unit_forget_bias,
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer,
        kernel_constraint=kernel_constraint,
        bias_constraint=bias_constraint,
        dtype=kwargs.get('dtype'))
    super().__init__(
        cell,
        return_sequences=return_sequences,
        return_state=return_state,
        go_backwards=go_backwards,
        stateful=stateful,
        **kwargs)
    # Activity regularization applies to the layer output, not the cell.
    self.activity_regularizer = regularizers.get(activity_regularizer)

  def call(self, inputs, mask=None, training=None, initial_state=None):
    """Pure pass-through to ConvRNN2D.call."""
    return super().call(
        inputs,
        mask=mask,
        training=training,
        initial_state=initial_state)


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source