While attempting to perform inference on a model (code below) using NUTS/MCMC, I get the error `ValueError: only one element tensors can be converted to Python scalars` after roughly 100 iterations. When I added a print statement for the relevant variable, which is a draw from a Categorical distribution, it became clear that the variable had turned into a two-element tensor, hence the error. However, I don't see why a draw from a Categorical distribution should return a two-element tensor.
I am running this in Google Colab, using Python 3.6, pyro-ppl-0.5.1, torch 1.3.0.
```python
import torch
import pyro.distributions as dist
import pyro
from pyro.infer import NUTS, MCMC
from torch.distributions import constraints
from pyro import poutine

# Two well-separated 2-D Gaussian clusters, 100 points each.
data1 = dist.MultivariateNormal(-5 * torch.ones(2), torch.eye(2)).sample([100])
data2 = dist.MultivariateNormal(5 * torch.ones(2), torch.eye(2)).sample([100])
data = torch.cat((data1, data2))

pyro.enable_validation(True)

N = len(data)
T = 2  # number of mixture components


def mix_weights(beta):
    # Stick-breaking: T-1 Beta draws -> T mixture weights that sum to 1.
    weights = torch.zeros(beta.shape[0] + 1)
    for t in range(beta.shape[0]):
        weights[t] = beta[t] * torch.prod(1. - beta[:t], dim=0)
    weights[beta.shape[0]] = 1. - torch.sum(weights)
    return weights


def model(data):
    alpha = 1.2
    with pyro.plate("beta_plate", T-1):
        beta = pyro.sample("beta", dist.Beta(1, alpha))
    with pyro.plate("mu_plate", T):
        mu = pyro.sample("mu", dist.MultivariateNormal(torch.zeros(2), 5 * torch.eye(2)))
    for i in range(N):
        z = pyro.sample("z_{}".format(i), dist.Categorical(mix_weights(beta)))
        print (z.numpy())
        pyro.sample("obs_{}".format(i), dist.MultivariateNormal(mu[z.item()], torch.eye(2)), obs=data[i])


nuts_kernel = NUTS(model, adapt_step_size=True)
mcmc = MCMC(nuts_kernel, num_samples=500, warmup_steps=300)
mcmc.run(data)
samples = mcmc.get_samples()
print(samples)
```
Output:

Warmup: 0%| | 0/800 [00:00, ?it/s]
0
0
0
0
0
0
1
1
0
1
0
1
0
0
1
0
1
0
0
0
0
0
0
0
0
0
0
1
0
0
1
0
1
0
0
0
0
0
0
0
0
0
0
0
0
1
0
0
0
0
0
0
1
0
0
0
1
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
1
1
0
0
0
0
0
1
0
0
0
0
0
0
0
0
0
0
0
0
1
0
0
1
0
0
1
0
0
0
0
0
1
0
0
0
1
0
1
1
0
1
0
0
0
1
0
0
0
0
0
1
0
0
0
0
0
0
1
1
0
0
0
0
0
0
0
1
1
0
1
0
0
0
0
0
0
0
0
0
1
0
0
1
0
0
0
0
0
0
0
0
0
0
0
0
0
0
1
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
[[0]
[1]]
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    142         try:
--> 143             ret = self.fn(*args, **kwargs)
    144         except (ValueError, RuntimeError):

16 frames

ValueError: only one element tensors can be converted to Python scalars

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-19-63983afdb4c3> in model(data)
     22         z = pyro.sample("z_{}".format(i), dist.Categorical(mix_weights(beta)))
     23         print (z.numpy())
---> 24         pyro.sample("obs_{}".format(i), dist.MultivariateNormal(mu[z.item()], torch.eye(2)), obs=data[i])
     25
     26 nuts_kernel = NUTS(model, adapt_step_size=True)

ValueError: only one element tensors can be converted to Python scalars
Trace Shapes:
 Param Sites:
Sample Sites:
   beta_plate dist      |
              value   1 |
         beta dist    1 |
              value   1 |
     mu_plate dist      |
              value   2 |
           mu dist    2 | 2
              value   2 | 2
          z_0 dist      |
              value 2 1 |
```
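Here is a minimal sketch, in plain PyTorch, of what the last print and the traceback show. The values of `mu` are made up for illustration. The `(2, 1)` shape assumes one level of plate nesting, as in this model, so that Pyro places the enumerated values at dim -2. An ordinary Categorical draw is a 0-dim tensor, so `.item()` works. An enumerated `z` holds every category at once, so `.item()` fails, while indexing with the tensor itself broadcasts.

```python
import torch

# Hypothetical component means for T = 2 clusters, shape (T, 2).
mu = torch.tensor([[-5.0, -5.0], [5.0, 5.0]])

# An ordinary draw of z is a 0-dim tensor, so .item() works.
z_scalar = torch.tensor(1)
print(mu[z_scalar.item()])

# Under enumeration, z holds every category along dim -2, i.e. shape (2, 1),
# which matches the [[0] [1]] printed just before the error.
z_enum = torch.arange(2).reshape(2, 1)
print(z_enum.numpy())

try:
    z_enum.item()
    item_failed = False
except (ValueError, RuntimeError):
    # torch 1.3 raises ValueError here; other releases may raise RuntimeError.
    item_failed = True

# Indexing with the tensor itself selects one mean per enumerated value.
selected = mu[z_enum]
print(selected.shape)  # torch.Size([2, 1, 2])
```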
@m-k-S In Pyro, NUTS marginalizes discrete latent variables by enumerating them, so `z` can hold all of its possible values at once. You should therefore use `mu[z]` in place of `mu[z.item()]`. However, you will need more machinery to enumerate more than 25 variables. For your model, I think replacing the `for` loop with a plate over the data is enough and will be much faster:

```python
with pyro.plate("data", N):
    z = pyro.sample("z", dist.Categorical(mix_weights(beta)))
    pyro.sample("obs", dist.MultivariateNormal(mu[z], torch.eye(2)), obs=data)
```
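The plated version works in both modes, and this can be sketched with `torch.distributions` alone. The data, means, and equal mixture weights below are hypothetical placeholders. Inside `pyro.plate("data", N)`, an ordinary draw of `z` has shape `(N,)` and the enumerated `z` has shape `(2, 1)`. In both cases `mu[z]` broadcasts against `data`. In the enumerated case, the extra leading dimension is what gets summed out in log space to marginalize `z`. This sketch mimics only the shapes and the arithmetic, not Pyro's internal bookkeeping.

```python
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
N, T = 200, 2
data = torch.randn(N, 2)                       # placeholder observations, shape (N, 2)
mu = torch.tensor([[-5.0, -5.0], [5.0, 5.0]])  # placeholder means, shape (T, 2)
weights = torch.full((T,), 1.0 / T)            # placeholder mixture weights

# Ordinary draw inside the plate: one label per data point.
z_sampled = torch.randint(0, T, (N,))
lp_sampled = MultivariateNormal(mu[z_sampled], torch.eye(2)).log_prob(data)
print(lp_sampled.shape)  # torch.Size([200])

# Enumerated z: all labels at dim -2, to the left of the plate dim.
z_enum = torch.arange(T).reshape(T, 1)
lp_enum = MultivariateNormal(mu[z_enum], torch.eye(2)).log_prob(data)
print(lp_enum.shape)  # torch.Size([2, 200])

# Marginalize z: log p(x_n) = logsumexp_k [log w_k + log N(x_n | mu_k, I)].
log_marginal = torch.logsumexp(weights.log().unsqueeze(-1) + lp_enum, dim=0)
print(log_marginal.shape)  # torch.Size([200])
```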
@fehiepsi do I understand correctly that the first few scalar samples were drawn during NUTS initialization, not during NUTS iteration, and that in fact the non-scalar site is encountered on the first NUTS step?
@fritzo Yes, the scalar samples are drawn while inspecting the model to detect its structure. After that, enumeration is activated by wrapping the model with `poutine.enum`. The non-scalar site is also encountered during the initialization phase.
Thanks @fehiepsi, this worked perfectly.
I have the same problem using a discrete distribution, but I am not using plates. Once enumeration starts, I see the multiple values, but then it stops taking samples: it runs twice with the multiple values and quits without an error message, after which `mcmc` has no samples.
@jmugan If your model does not have continuous latent variables, MCMC will collect nothing. Could you create a separate issue if that is not the case?
It does have continuous variables as well. I'll create a small example program to illustrate the problem and open a separate issue. Thanks!