Warning: "Variables were used in a Lambda layer's call" — this results in the layer not being trainable, even though the layer is already subclassed
I want to create this custom ConvRNN2D layer.
Therefore I copied the code for the ConvLSTM2D layer by tensorflow.
I adapted this code until it now fits the figure above.
The problem:
- Tensorflow outputs the warning that the kernels and bias are not tracked.
- The Model does not count any trainable variables in the parameter column for the layer.
- The Model does not list the kernels and bias in "model.trainable_variables"
- All this even results in the EXACT same problem when I copy the tensorflow code (ConvLSTM2D) without ANY changes.
My questions:
- How do I make the kernels and bias trackable by tensorflow so the layer becomes trainable?
- Why does even ConvLSTM2D have the SAME problem when I copy and paste the code without touching it?
- Why does tensorflow list the internal layer structure in model.summary()?
Expectations vs reality:
Code?
Test model
import numpy
import tensorflow
from tensorflow.keras.layers import Conv2D, TimeDistributed

# Dummy batch: 1 sample, 7 time steps, 128x128 frames with 5 channels.
data = numpy.random.rand(1, 7, 128, 128, 5)

# Time, height and width stay dynamic; only the channel count is fixed.
inputLayer = tensorflow.keras.Input(shape=(None, None, None, data.shape[-1]))
encode = TimeDistributed(
    Conv2D(32, kernel_size=(3, 3), padding='same'))(inputLayer)
recurrent = AECRNN(32, (3, 3), padding='same', return_sequences=True)(encode)

model = tensorflow.keras.Model(inputLayer, recurrent)
model.compile()
model.summary()
My ConvRNN2D layer
- 80% of the lines are boilerplate copied from ConvLSTM2D.
- The only major changes made are in the cell-class, in the functions "build" and "call".
from tensorflow.keras.layers import Layer
from tensorflow.python.ops import array_ops
from tensorflow.python.keras import backend
from tensorflow.python.keras import activations
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.layers.convolutional_recurrent import ConvRNN2D
class AECRNNCell(Layer):
    """Convolutional recurrent cell, adapted from Keras' ConvLSTM2D cell.

    Holds one input-convolution kernel, one recurrent kernel (split into two
    halves along its last axis) and an optional bias.  ``state_size`` exposes
    two per-step states of ``filters`` channels each, as ``ConvRNN2D``
    requires.

    BUGFIX: the original version created all weights inside ``__init__`` via
    ``add_weight``.  Weights created before the enclosing ``ConvRNN2D`` starts
    tracking the cell are not registered as trainable variables — which
    produces the "Variables were used in a Lambda layer's call" warning, a
    zero trainable-parameter count and an empty ``model.trainable_variables``.
    Weight creation is therefore moved into ``build()``, the Keras-sanctioned
    place for it.  NOTE(review): mixing the public ``tensorflow.keras`` API
    with the internal ``tensorflow.python.keras`` modules (as the file-level
    imports do) can cause the same symptom — TODO confirm and prefer the
    public API throughout.
    """

    def __init__(self,
                 filters,
                 kernel_size,
                 strides=(1, 1),
                 padding='same',
                 data_format=None,
                 dilation_rate=(1, 1),
                 activation='relu',
                 use_bias=True,
                 kernel_initializer='orthogonal',
                 bias_initializer='zeros',
                 unit_forget_bias=True,
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        super(AECRNNCell, self).__init__(**kwargs)
        self.filters = filters
        self.state_size = (self.filters, self.filters)
        self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
        # Provisional shape kept for backward compatibility; recomputed in
        # build() from the actual input channel count.
        self.kernel_shape = self.kernel_size + (self.filters, self.filters)
        self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
        self.padding = conv_utils.normalize_padding(padding)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2, 'dilation_rate')
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.unit_forget_bias = unit_forget_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_initializer = initializers.get(bias_initializer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.bias_constraint = constraints.get(bias_constraint)
        # Weights are deliberately NOT created here — see build().

    def build(self, input_shape):
        """Create kernel, recurrent kernel and bias once the input shape is known.

        Generalizes the original: the input kernel's channel dimension is read
        from ``input_shape`` instead of assuming it equals ``filters``.
        """
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        input_dim = input_shape[channel_axis]
        if input_dim is None:
            # Fall back to the original assumption (input channels == filters).
            input_dim = self.filters
        self.kernel_shape = self.kernel_size + (int(input_dim), self.filters)
        self.kernel = self.add_weight(
            shape=self.kernel_shape, name='kernel',
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint)
        self.recurrent_kernel = self.add_weight(
            shape=self.kernel_size + (self.filters, self.filters * 2),
            name='recurrent_kernel',
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint)
        if self.use_bias:
            self.bias = self.add_weight(
                shape=(self.filters,), name='bias',
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint)
        else:
            self.bias = None
        self.built = True

    def call(self, inputs, states, training=None):
        """Run one recurrence step; returns (output, [new_state_0, new_state_1]).

        NOTE(review): the first two assignments to ``result`` (the biased
        input convolution concatenated with ``states[0]``, and the kernel2
        convolution) are immediately overwritten — only the kernel3
        convolution of ``inputs`` reaches the output, and ``self.activation``
        is never applied.  Behavior is preserved as-is; confirm this matches
        the intended architecture.
        """
        result = self.input_conv(inputs, self.kernel, self.bias if self.use_bias else None)
        result = backend.concatenate([result, states[0]])
        kernel2, kernel3 = array_ops.split(self.recurrent_kernel, 2, axis=3)
        result = self.recurrent_conv(inputs, kernel2)
        result = self.recurrent_conv(inputs, kernel3)
        return result, [result, inputs]

    def input_conv(self, x, w, b=None, padding='valid'):
        """Convolve ``x`` with ``w`` and add bias ``b`` if given.

        NOTE(review): the ``padding`` parameter is unused — the convolution
        always uses ``self.padding`` (kept for signature compatibility).
        """
        conv_out = self.recurrent_conv(x, w)
        if b is not None:
            conv_out = backend.bias_add(conv_out, b, data_format=self.data_format)
        return conv_out

    def recurrent_conv(self, x, w):
        """Plain 2D convolution of ``x`` with ``w`` using the cell's settings."""
        return backend.conv2d(x, w,
                              strides=self.strides,
                              padding=self.padding,
                              data_format=self.data_format,
                              dilation_rate=self.dilation_rate)
class AECRNN(ConvRNN2D):
    """Convolutional RNN layer that drives an ``AECRNNCell`` via ``ConvRNN2D``.

    Constructs the cell from the convolution/recurrence hyper-parameters and
    hands it to the ``ConvRNN2D`` base class, which unrolls it over the time
    dimension.  The activity regularizer is stored on the layer itself, as
    Keras' own ConvLSTM2D does.
    """

    def __init__(self,
                 filters,
                 kernel_size,
                 strides=(1, 1),
                 padding='same',
                 data_format=None,
                 dilation_rate=(1, 1),
                 activation='relu',
                 use_bias=True,
                 kernel_initializer='orthogonal',
                 bias_initializer='zeros',
                 unit_forget_bias=True,
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 return_sequences=False,
                 return_state=False,
                 go_backwards=False,
                 stateful=False,
                 **kwargs):
        # Build the recurrent cell first; the base class only needs the cell
        # plus the generic RNN flags.
        rnn_cell = AECRNNCell(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            unit_forget_bias=unit_forget_bias,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            dtype=kwargs.get('dtype'))
        super(AECRNN, self).__init__(
            rnn_cell,
            return_sequences=return_sequences,
            return_state=return_state,
            go_backwards=go_backwards,
            stateful=stateful,
            **kwargs)
        self.activity_regularizer = regularizers.get(activity_regularizer)

    def call(self, inputs, mask=None, training=None, initial_state=None):
        """Delegate straight to ConvRNN2D's recurrence driver."""
        return super(AECRNN, self).call(
            inputs, mask=mask, training=training, initial_state=initial_state)
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|