Keras ValueError: Dimensions must be equal - How to pass label-dependent values to a custom loss function

I have a Keras model with 5 outputs. My labels include the 5 values to compare these to, but also 25 additional values representing a correlation matrix for the 5 values (specifically, it is the inverse covariance matrix). I want a custom loss function that includes the proper covariance (see the formula below). I have written one and tested that it works. Keras has no problem compiling it, but complains when I try to fit it. The issue is that it seems to require that the number of model outputs match the number of values in the label.

Loss function using the correlation matrix: loss = (y_pred - y_true)^T * C^-1 * (y_pred - y_true), where C^-1 is the inverse covariance matrix of the 5 label values.

From what I have seen, most of the issues people have with custom loss functions and extra parameters involve passing additional static parameters (i.e. ones that do not change from label to label). Here is my loss function, with the testing code at the bottom:

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras import backend as K

    def customLoss(y_true, y_pred):

        batch_size = tf.shape(y_pred)[0]  # n.b. y_pred.shape[0] will not work for some reason in tf1
        print('y_pred shape: ' + str(y_pred.shape))   # y_pred shape is (batch, 5)
        print('y_true shape: ' + str(y_true.shape))   # y_true shape is (batch, 49)
        print('y_pred type: ' + str(type(y_pred)))
        print('y_true type: ' + str(type(y_true)))

        # Note that y_pred only has the 5 state vector parameters while y_true contains
        # all of the labels (event, state vector, covariance matrix, inverse cov., ...).
        # We peel off the state vector and inverse covariance here, which are the parts
        # we need.
        y_state = y_true[:, 1:6]                           # y_state shape is now (batch, 5)
        invcov  = y_true[:, 21:46]                         # invcov  shape is now (batch, 25)

        y_pred  = K.reshape(y_pred,  (batch_size, 5, 1))   # y_pred  shape is now (batch, 5, 1)
        y_state = K.reshape(y_state, (batch_size, 5, 1))   # y_state shape is now (batch, 5, 1)
        invcov  = K.reshape(invcov,  (batch_size, 5, 5))   # invcov  shape is now (batch, 5, 5)

        # n.b. we must use tf.transpose here and not K.transpose since the latter does not accept a perm argument
        invcov = tf.transpose(invcov, perm=[0, 2, 1])      # invcov shape is still (batch, 5, 5)

        # Difference between predicted and true state vectors
        y_diff = y_pred - y_state

        # n.b. use "batch_dot" and not "dot"!
        y_dot  = K.batch_dot(invcov, y_diff)               # y_dot  shape is (batch, 5, 1)
        y_dot  = K.reshape(y_dot, (batch_size, 1, 5))      # y_dot  shape is now (batch, 1, 5)
        y_loss = K.batch_dot(y_dot, y_diff)                # y_loss shape is (batch, 1, 1)
        y_loss = K.reshape(y_loss, (batch_size,))          # y_loss shape is now (batch,)
        return y_loss

    # Test the loss function
    xx = np.arange(0.1, 4.9, 0.1).tolist()  # list of 49 values; this plays the role of y_true (only 5+25=30 are used)
    yy = [1.0, 2.0, 3.0, 4.0, 5.0]          # list of 5 values; this plays the role of y_pred

    loss = K.eval(customLoss(K.variable([xx, xx, xx]), K.variable([yy, yy, yy])))
    print('loss shape: ' + str(loss.shape))
    print(loss)

The output from testing the loss is:

y_pred shape: (3, 5)
y_true shape: (3, 49)
y_pred type: <class 'tensorflow.python.ops.resource_variable_ops.ResourceVariable'>
y_true type: <class 'tensorflow.python.ops.resource_variable_ops.ResourceVariable'>
loss shape: (3,)
[644.8 644.8 644.8]

Here is the error I get when I try to fit:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-7-d1d9dff4fee8> in <module>
     36     shuffle=True,
     37     initial_epoch = epoch_loaded,
---> 38     use_multiprocessing=False
     39 )
     40 

/w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
     64   def _method_wrapper(self, *args, **kwargs):
     65     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
---> 66       return method(self, *args, **kwargs)
     67 
     68     # Running inside `run_distribute_coordinator` already.

/w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
    846                 batch_size=batch_size):
    847               callbacks.on_train_batch_begin(step)
--> 848               tmp_logs = train_function(iterator)
    849               # Catch OutOfRangeError for Datasets of unknown size.
    850               # This blocks until the batch has finished executing.

/w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    578         xla_context.Exit()
    579     else:
--> 580       result = self._call(*args, **kwds)
    581 
    582     if tracing_count == self._get_tracing_count():

/w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    625       # This is the first call of __call__, so we have to initialize.
    626       initializers = []
--> 627       self._initialize(args, kwds, add_initializers_to=initializers)
    628     finally:
    629       # At this point we know that the initialization is complete (or less

/w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    504     self._concrete_stateful_fn = (
    505         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 506             *args, **kwds))
    507 
    508     def invalid_creator_scope(*unused_args, **unused_kwds):

/w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   2444       args, kwargs = None, None
   2445     with self._lock:
-> 2446       graph_function, _, _ = self._maybe_define_function(args, kwargs)
   2447     return graph_function
   2448 

/w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   2775 
   2776       self._function_cache.missed.add(call_context_key)
-> 2777       graph_function = self._create_graph_function(args, kwargs)
   2778       self._function_cache.primary[cache_key] = graph_function
   2779       return graph_function, args, kwargs

/w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2665             arg_names=arg_names,
   2666             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2667             capture_by_value=self._capture_by_value),
   2668         self._function_attributes,
   2669         # Tell the ConcreteFunction to clean up its graph once it goes out of

/w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    979         _, original_func = tf_decorator.unwrap(python_func)
    980 
--> 981       func_outputs = python_func(*func_args, **func_kwargs)
    982 
    983       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    439         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    440         # the function a weak reference to itself to avoid a reference cycle.
--> 441         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    442     weak_wrapped_fn = weakref.ref(wrapped_fn)
    443 

/w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint:disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise

ValueError: in user code:

    /w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:571 train_function  *
        outputs = self.distribute_strategy.run(
    /w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:951 run  **
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:2290 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:2649 _call_for_each_replica
        return fn(*args, **kwargs)
    /w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:543 train_step  **
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
    /w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/keras/engine/compile_utils.py:411 update_state
        metric_obj.update_state(y_t, y_p)
    /w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/keras/utils/metrics_utils.py:90 decorated
        update_op = update_state_fn(*args, **kwargs)
    /w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py:603 update_state
        matches = self._fn(y_true, y_pred, **self._fn_kwargs)
    /w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/keras/losses.py:1230 mean_absolute_error
        return K.mean(math_ops.abs(y_pred - y_true), axis=-1)
    /w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:984 binary_op_wrapper
        return func(x, y, name=name)
    /w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py:10103 sub
        "Sub", x=x, y=y, name=name)
    /w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:744 _apply_op_helper
        attrs=attr_protos, op_def=op_def)
    /w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py:595 _create_op_internal
        compute_device)
    /w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:3327 _create_op_internal
        op_def=op_def)
    /w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:1817 __init__
        control_input_ops, op_def)
    /w/halld-scifs17exp/halld2/home/davidl/builds/Python_VENV/venv_2020.06.02/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:1657 _create_c_op
        raise ValueError(str(e))

    ValueError: Dimensions must be equal, but are 5 and 49 for '{{node sub}} = Sub[T=DT_FLOAT](model/outputs/Relu, Cast_2)' with input shapes: [1000,5], [1000,49].

So the question is: how do I pass additional values, which vary for each set of inputs, to the custom loss function?



Solution 1:[1]

Here is a workaround to pass additional parameters, which vary for each set of inputs, to the custom loss function.

I provide a dummy example for a regression problem:

    import numpy as np
    from tensorflow.keras import backend as K
    from tensorflow.keras.layers import Input, Dense
    from tensorflow.keras.models import Model

    def mse(y_true, y_pred, external):

        score = K.mean(K.square(y_true - y_pred))   # ordinary MSE between true and predicted values
        # do something with external...
        # here I simply add its mean to the final score (no real meaning, just a demo)

        return score + K.mean(external)


    X = np.random.uniform(0, 1, (1000, 10))
    y = np.random.uniform(0, 3, (1000, 49))

    inp          = Input((10,))      # ordinary network input
    true         = Input((5,))       # the 5 label values the loss compares the predictions against
    external_arg = Input((49 - 5,))  # the remaining per-sample label values used inside the loss

    x = Dense(32, activation='relu')(inp)
    pred = Dense(5)(x)

    # the training model receives the true values and the extra label values
    # as additional inputs so that add_loss can use them
    m = Model([inp, external_arg, true], pred)
    m.add_loss(mse(true, pred, external_arg))
    m.compile(loss=None, optimizer='adam')
    history = m.fit([X, y[:, 5:], y[:, :5]], y[:, :5], epochs=10)

    # final fitted model to compute predictions
    final_m = Model(inp, pred)
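
At inference time only `final_m` is needed: it shares its layers (and therefore the fitted weights) with `m`, and it takes just the ordinary input, so the extra label values do not have to be supplied to get predictions. For example:

    preds = final_m.predict(X)
    print(preds.shape)   # (1000, 5)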

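To connect the dummy example back to the question: below is a minimal sketch of the same add_loss pattern carrying the chi-square loss instead of the dummy mse. The input width of 10 and all variable names are placeholders; the label layout (columns 1:6 hold the state vector, columns 21:46 the flattened inverse covariance) is taken from the question, and the quadratic form is computed with `tf.matmul` rather than `K.batch_dot`.

    import tensorflow as tf
    from tensorflow.keras import backend as K
    from tensorflow.keras.layers import Input, Dense
    from tensorflow.keras.models import Model

    inp     = Input((10,))   # ordinary network input (width 10 is a placeholder)
    y_state = Input((5,))    # the 5 true state-vector values taken from the label
    invcov  = Input((25,))   # the flattened 5x5 inverse covariance taken from the label

    x = Dense(32, activation='relu')(inp)
    pred = Dense(5)(x)

    def chi2_loss(y_state, y_pred, invcov_flat):
        # per-sample (y_pred - y_state)^T * C^-1 * (y_pred - y_state), averaged over the batch
        batch = tf.shape(y_pred)[0]
        diff  = K.reshape(y_pred - y_state, (batch, 5, 1))
        C_inv = K.reshape(invcov_flat, (batch, 5, 5))
        quad  = tf.matmul(tf.matmul(diff, C_inv, transpose_a=True), diff)  # shape (batch, 1, 1)
        return K.mean(K.reshape(quad, (batch,)))

    train_model = Model([inp, y_state, invcov], pred)
    train_model.add_loss(chi2_loss(y_state, pred, invcov))
    train_model.compile(loss=None, optimizer='adam')

    # X: the network inputs, labels: the 49-wide label array from the question
    # train_model.fit([X, labels[:, 1:6], labels[:, 21:46]], epochs=10)

    # as above, only the plain input is needed for predictions
    predict_model = Model(inp, pred)
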
Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution 1: Marco Cerliani