tensorflow_probability: TransformedDistribution not accepting event_shape and batch_shape arguments

In version 0.11.0 of Tensorflow Probability, I can define a TransformedDistribution as follows, indicating event and batch shape:

mvn = tfd.TransformedDistribution(normal, scale_lin_op, event_shape=[4], batch_shape=[2])

However, in the current version (0.12.1 as of the writing of this post), the event_shape and batch_shape arguments appear to have been removed, because the previous line produces the error:

TypeError: __init__() got an unexpected keyword argument 'event_shape'

QUESTION: How should the event and batch shapes be overridden in the current versions? Is there some workaround or known alternative?


Note for reproducibility: This is how I define a simple distribution and a bijector:

import tensorflow_probability as tfp
import tensorflow as tf
tfb = tfp.bijectors
tfd = tfp.distributions

# Define a simple scalar normal distribution
normal = tfd.Normal(loc=0., scale=1.)

# Define a bijector based on a linear operator: a batch of two 4x4 matrices
# (LinearOperatorLowerTriangular only uses the lower triangle of each)
tril = tf.random.normal((2, 4, 4))
scale_low_tri = tf.linalg.LinearOperatorLowerTriangular(tril)
scale_lin_op = tfb.ScaleMatvecLinearOperator(scale_low_tri)

Remark: I know that mvn2 = tfd.MultivariateNormalLinearOperator(loc=0, scale=scale_low_tri) is equivalent and yields the correct batch and event shapes. What I would like to know is the "new" way of overriding the event and batch shapes when using TransformedDistribution.



Solution 1:

You can use tfd.Sample to sample extra event dimensions.

At the time of the change, there was no suggested way to override the underlying batch shape other than using larger parameters for the distribution, as you have identified. However, the newly added tfd.BatchBroadcast (available in tfp-nightly; see pip install tfp-nightly) lets you broadcast a distribution to a larger batch shape.

If you can only depend on a released version, your best bet is what you already have (or use tfd.MultivariateNormalTriL directly and skip the LinearOperator step).

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1: Brian Patton