tensorflow_probability: TransformedDistribution not accepting event_shape and batch_shape arguments
In version 0.11.0 of TensorFlow Probability, I can define a TransformedDistribution as follows, specifying the event and batch shapes:
mvn = tfd.TransformedDistribution(normal, scale_lin_op, event_shape=[4], batch_shape=[2])
However, in the current version (0.12.1 as of the writing of this post), the event_shape and batch_shape arguments seem to have been removed, because the previous line now produces the error:
TypeError: __init__() got an unexpected keyword argument 'event_shape'
QUESTION: How should the event and batch shapes be overridden in the current versions? Is there some workaround or known alternative?
Note for reproducibility: This is how I define a simple distribution and a bijector:
import tensorflow_probability as tfp
import tensorflow as tf
tfb = tfp.bijectors
tfd = tfp.distributions
# Define simple normal distribution
normal = tfd.Normal(loc=0., scale=1.)
# Define bijector based on a linear operator
tril = tf.random.normal((2, 4, 4))
scale_low_tri = tf.linalg.LinearOperatorLowerTriangular(tril)
scale_lin_op = tfb.ScaleMatvecLinearOperator(scale_low_tri)
Remark: I know that doing mvn2 = tfd.MultivariateNormalLinearOperator(loc=0, scale=scale_low_tri) is equivalent and yields the correct batch and event shapes, but what I would like to know is the "new" way of overriding the event and batch shapes when using TransformedDistribution.
Solution 1:[1]
You can use tfd.Sample to add extra (iid) event dimensions to the base distribution.
At the time of the change, there wasn't a suggested way to override the underlying batch shape other than "use larger parameters for the distribution", as you have identified. However, the newly added tfd.BatchBroadcast (available in the nightly build; see pip install tfp-nightly) lets you broadcast a distribution to a larger batch shape.
If you can only depend on a release, your best bet is what you have (or using MultivariateNormalTriL directly and skipping the LinearOperator bit).
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | Brian Patton |