In the following example, the shape of the value sampled at the LKJCorrCholesky distributed sample statement differs between the model and the auto-generated guide. Is this expected?
import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.autoguide import AutoDiagonalNormal
d = dist.LKJCorrCholesky(2, torch.tensor(1.))
def model():
    return {'x': pyro.sample('x', d)}
guide = AutoDiagonalNormal(model)
print(model()['x'].shape) # == [2,2]
print(guide()['x'].shape) # == [1,2,2]
Environment: macOS, Python 3.7.3, Pyro 0.3.3+bf2f9542
I dug into this a little while trying to understand what was happening. If this is a bug, it may be that the problem stems from the shape of the value returned by _unpack_latent:
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import biject_to
from pyro.contrib.autoguide import AutoDiagonalNormal
d = dist.LKJCorrCholesky(2, torch.tensor(1.))
support = d.support
def model():
    return {'x': pyro.sample('x', d)}
guide = AutoDiagonalNormal(model)
latent = guide.sample_latent()
unconstrained_value = [val for site, val in guide._unpack_latent(latent)
                       if site['name'] == 'x'][0]
# The shape of the unconstrained value is [1,1]
print(unconstrained_value.shape) # == [1,1]
# Which gives a [1,2,2] tensor once transformed:
print(biject_to(support)(unconstrained_value).shape) # == [1,2,2]
# Reshaping the unconstrained value from [1,1] to [1] gives the result
# of the transform the expected shape.
print(biject_to(support)(unconstrained_value.reshape(-1)).shape) # == [2,2]
Yes, I think it is a bug at this line. Probably we should replace it by
unconstrained_shape = broadcast_shape(unconstrained_shape, batch_shape + unconstrained_shape)
The reason is that, for the LKJ transform, the unconstrained value is a vector while the constrained value is a matrix.
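To make the vector-versus-matrix point concrete, here is a rough, pure-Python sketch of the shape arithmetic only. It assumes the transform maps a trailing vector of length D(D-1)/2 to a D x D matrix, which is consistent with the [1] -> [2,2] and [1,1] -> [1,2,2] shapes reported above. The helper name `corr_cholesky_output_shape` is made up for illustration and is not part of Pyro.

```python
import math


def corr_cholesky_output_shape(unconstrained_shape):
    """Hypothetical helper (not a Pyro API): output shape of a transform that
    maps a trailing vector of length D*(D-1)/2 to a D x D matrix."""
    batch = tuple(unconstrained_shape[:-1])
    n = unconstrained_shape[-1]
    # Solve n = D * (D - 1) / 2 for D, i.e. D = (1 + sqrt(1 + 8n)) / 2.
    root = int(round(math.sqrt(1 + 8 * n)))
    if root * root != 1 + 8 * n:
        raise ValueError("trailing size {} is not of the form D*(D-1)/2".format(n))
    dim = (1 + root) // 2
    # Leading (batch) dimensions pass through unchanged.
    return batch + (dim, dim)


# A plain vector of length 1 becomes a 2 x 2 matrix ...
print(corr_cholesky_output_shape((1,)))
# ... while an extra leading dimension is carried along as a batch dimension.
print(corr_cholesky_output_shape((1, 1)))
```

So the extra leading 1 in the unconstrained value is what shows up as the extra leading 1 in the guide's sample.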
@fehiepsi good sleuthing! I think we'll need to account for overlapping batch_shape and unconstrained_shape, thus I suspect batch_shape + unconstrained_shape won't work. One fix that should work is
  def _unpack_latent(self, latent):
      ...
+     constrained_shape = site["value"].shape
      unconstrained_shape = self._unconstrained_shapes[name]
+     event_dim = site["fn"].event_dim + len(unconstrained_shape) - len(constrained_shape)
      unconstrained_shape = broadcast_shape(unconstrained_shape,
-                                           batch_shape + (1,) * site["fn"].event_dim)
+                                           batch_shape + (1,) * event_dim)
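For anyone following along, the shape arithmetic in this patch can be checked by hand for the example in this issue. The sketch below is pure Python and makes several assumptions: `broadcast_shape` here is a minimal stand-in written for illustration (not Pyro's implementation), `batch_shape` is taken to be empty for a single draw, and the stored unconstrained shape for `x` is taken to be `(1,)`, matching the shapes reported earlier in the thread.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Minimal stand-in for a broadcasting helper (illustration only)."""
    result = []
    # Align shapes on the right and pad missing leading dims with 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError("shapes {} are not broadcastable".format(shapes))
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


# Assumed values for the LKJCorrCholesky(2, ...) site in this issue.
batch_shape = ()             # assumption: single draw, no batch dims
fn_event_dim = 2             # the constrained sample is a 2 x 2 matrix
constrained_shape = (2, 2)
unconstrained_shape = (1,)   # assumption: one free parameter for d = 2

# Before the patch: pads with one 1 per *constrained* event dim.
old_shape = broadcast_shape(unconstrained_shape,
                            batch_shape + (1,) * fn_event_dim)

# After the patch: correct for the rank difference between the
# unconstrained vector and the constrained matrix.
event_dim = fn_event_dim + len(unconstrained_shape) - len(constrained_shape)
new_shape = broadcast_shape(unconstrained_shape,
                            batch_shape + (1,) * event_dim)

print(old_shape)
print(new_shape)
```

Under these assumptions the old expression yields `(1, 1)`, the shape observed in the report, and the patched one yields `(1,)`, which the transform maps to a `[2,2]` matrix.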
@fehiepsi do you have time to try one of the fixes?
Yup, I'll try to fix shortly.