Cannot use set_shape with tf.function
A gradient is a nested list of tensors. I want to get the total number of elements in the gradient, and record this number as an int. However, I don't know how to do this in tf.function.
import tensorflow as tf
grad = [tf.ones((500,200)), tf.ones((200,)), tf.ones((200,1))]
def test(grad):
    m = tf.cast(0, tf.int32)
    for i in grad:
        m = m + tf.math.reduce_prod(tf.shape(i))
    out = tf.zeros((m,))
    out.set_shape((m,))
    return out
The code works as intended in eager mode. If you apply tf.function, you get the following error:
TypeError: Dimension value must be integer or None or have an index method, got value '<tf.Tensor 'add_2:0' shape=() dtype=int32>' with type '<class 'tensorflow.python.framework.ops.Tensor'>'
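The issue is that in eager mode 'm' is <tf.Tensor: shape=(), dtype=int32, numpy=100400>, but inside tf.function it is the symbolic <tf.Tensor 'add_2:0' shape=() dtype=int32>, which has no concrete value yet and so cannot be used as a dimension in set_shape.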
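A small sketch of why this happens, assuming TensorFlow 2.x with default tf.function settings: the Python body runs while the function is being traced, so 'm' is a graph tensor with no value at that point, whereas the static shapes of the inputs are already ordinary Python ints. The trace_log dict and the total_size name are hypothetical, added only to record what is visible during tracing.

```python
import tensorflow as tf

grad = [tf.ones((500, 200)), tf.ones((200,)), tf.ones((200, 1))]

trace_log = {}  # hypothetical; filled by Python side effects, which run only while tracing

@tf.function
def total_size(grad):
    m = tf.cast(0, tf.int32)
    for i in grad:
        m = m + tf.math.reduce_prod(tf.shape(i))
    # While tracing, ops are added to a graph instead of executed, so m has no value yet.
    trace_log["eager"] = tf.executing_eagerly()
    # Static shapes, in contrast, are plain Python ints at trace time.
    trace_log["static_sizes"] = [i.shape.num_elements() for i in grad]
    return m

result = total_size(grad)
print(trace_log)  # {'eager': False, 'static_sizes': [100000, 200, 200]}
print(int(result))  # 100400, available only after the graph has run
```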
Solution 1:[1]
Per Jirayu, here is the workaround:
import tensorflow as tf
grad = [tf.ones((500,200)), tf.ones((200,)), tf.ones((200,1))]
@tf.function
def test(grad):
    m = tf.cast(0, tf.int32)
    for i in grad:
        m = m + tf.math.reduce_prod(tf.shape(i))
    out = tf.zeros((m,))
    m = out.shape[0]  # static length (a Python value, not a tensor); this m can be used to define shape
    out.set_shape((m,))
    return out
Solution 2:[2]
Compute the total size first and create the output from it, then take the static length from the output's own .shape instead of passing the tensor 'm' to set_shape.
Sample:
import tensorflow as tf
grad = [tf.ones((500,200)), tf.ones((200,)), tf.ones((200,1))]
@tf.function
def test(grad):
    m = tf.cast(0, tf.int32)
    for i in grad:
        m = m + tf.math.reduce_prod(tf.shape(i))
    # Under AutoGraph, int() on a tensor is converted to a cast, so the size is still a tensor
    out = tf.zeros((int(m),))
    # out.shape[0] is a static value (an int, or None if unknown), which set_shape accepts
    out.set_shape((out.shape[0],))
    return out
print(test(grad))
Solution 3:[3]
You don't have to create a zero tensor just to get the number of elements in the gradient.
import tensorflow as tf
grad = [tf.ones((500,100)), tf.ones((200,)), tf.ones((100,1,3))]
@tf.function
def test(grad):
    m = tf.reduce_sum([tf.reduce_prod(i.shape) for i in grad])
    return m
test(grad)  # returns a scalar int32 tensor: 50500
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | Taw |
Solution 2 | General Grievance |
Solution 3 |