Cannot use set_shape with tf.function

A gradient is a nested list of tensors. I want to get the total number of elements in the gradient, and record this number as an int. However, I don't know how to do this in tf.function.

import tensorflow as tf
grad = [tf.ones((500,200)), tf.ones((200,)), tf.ones((200,1))]
def test(grad):
    m = tf.cast(0, tf.int32)
    for i in grad:
        m = m + tf.math.reduce_prod(tf.shape(i))
    out = tf.zeros((m,))
    out.set_shape((m,))
    return out

The code works as intended in eager mode. If you decorate test with @tf.function, you get the following error:

TypeError: Dimension value must be integer or None or have an index method, got value '<tf.Tensor 'add_2:0' shape=() dtype=int32>' with type '<class 'tensorflow.python.framework.ops.Tensor'>'

The issue is that 'm' should be <tf.Tensor: shape=(), dtype=int32, numpy=100400> but it is <tf.Tensor 'add_2:0' shape=() dtype=int32>.
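
For reference, the expected value 100400 can be checked by hand from the three shapes in grad. The following is a minimal sketch in plain Python (no TensorFlow needed); the names shapes and num_elements are illustrative and not part of the original code.

```python
import math

# Shapes of the three tensors in `grad` from the question.
shapes = [(500, 200), (200,), (200, 1)]

def num_elements(shape):
    # Product of all dimensions; math.prod(()) is 1, which matches a scalar.
    return math.prod(shape)

# 500*200 + 200 + 200*1
total = sum(num_elements(s) for s in shapes)
print(total)  # 100400
```

Because these shapes are fixed Python tuples, the count is an ordinary int and never has to pass through a symbolic tensor.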



Solution 1:[1]

Per Jirayu, here is the workaround:

import tensorflow as tf
grad = [tf.ones((500,200)), tf.ones((200,)), tf.ones((200,1))]
@tf.function
def test(grad):
    m = tf.cast(0, tf.int32)
    for i in grad:
        m = m + tf.math.reduce_prod(tf.shape(i))
    out = tf.zeros((m,))
    m = out.shape[0]  # static size (Python int or None); this m can be used to define shape
    out.set_shape((m,))
    return out

Solution 2:[2]

Calculate the element count first, then use the precomputed value to define the shape.

[ Sample ]:

import tensorflow as tf
grad = [tf.ones((500,200)), tf.ones((200,)), tf.ones((200,1))]

@tf.function
def test(grad):
    m = tf.cast(0, tf.int32)
    m = tf.add(m, 0)
    for i in grad:
        # reduce_prod multiplies the dimensions of i to get its element count
        m = m + tf.math.reduce_prod(tf.shape(i))
    out = tf.zeros((int(m),))
    out.set_shape((out.shape[0],))

    return out

print( test( grad ) )

Solution 3:[3]

You don't have to create a zero tensor just to get the number of elements in the gradient.

import tensorflow as tf
grad = [tf.ones((500,100)), tf.ones((200,)), tf.ones((100,1,3))]
@tf.function
def test(grad):
    m = tf.reduce_sum([tf.reduce_prod(i.shape) for i in grad])
    return m

test(grad)

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution      Source
Solution 1    Taw
Solution 2    General Grievance
Solution 3