I'm building a model similar to LDA, but with a time-dependence on the topic proportions.
Here's the model:
@config_enumerate(default='parallel')
@poutine.broadcast
def model(data):
    beta = pyro.sample('beta', dist.Dirichlet(torch.ones(K, V)).independent(1))
    # print('model beta:\n', beta)
    for t in pyro.irange('timeslice', T):
        if t == 0:
            theta = pyro.sample('theta_t{}'.format(t), dist.Normal(torch.ones(K), SCALE).independent(1))
        else:
            theta = pyro.sample('theta_t{}'.format(t), dist.Normal(theta_prev, SCALE).independent(1))
        theta = torch.tensor([max(i, 0.001) for i in theta])
        theta = theta / sum(theta)
        for d in pyro.irange('docs_t{}'.format(t), N[t]):
            z = pyro.sample('z_t{}_d{}'.format(t, d), dist.Categorical(theta))
            with pyro.iarange('words_t{}_d{}'.format(t, d), L[t][d]):
                pyro.sample('w_t{}_d{}'.format(t, d), dist.Categorical(beta[z,:]), obs=data[t][d])
        theta_prev = theta
When I run SVI with dummy data, it runs fine. However, when I increase the vocabulary size (the failure seems to start at V = 14), I get ValueError: The value argument must be within the support when calculating the log_prob of beta. This seems to be because the rows of the sampled tensor do not sum to 1 (to within 1e-6).
I recreate the error explicitly in the guide, printing out the tensor that is creating the error:
beta_q:
tensor([[ 0.0272, 0.0025, 0.0088, 0.2814, 0.0172, 0.0232, 0.1008,
0.0317, 0.0171, 0.0292, 0.0171, 0.0790, 0.0006, 0.0717,
0.0118, 0.0018, 0.0604, 0.0406, 0.1116, 0.0494, 0.0167],
[ 0.0227, 0.0521, 0.0349, 0.2412, 0.0265, 0.0291, 0.0314,
0.0229, 0.0444, 0.0044, 0.0052, 0.0720, 0.0585, 0.0319,
0.0550, 0.0325, 0.0646, 0.0005, 0.0970, 0.0050, 0.0681]])
test:
Independent()
sample:
tensor([[ 1.1921e-07, 1.1921e-07, 1.1921e-07, 6.9644e-02, 1.1921e-07,
1.1921e-07, 1.3194e-07, 6.8583e-05, 1.1921e-07, 1.1921e-07,
1.1921e-07, 9.8113e-05, 1.1921e-07, 9.6791e-03, 1.1921e-07,
1.1921e-07, 1.1700e-02, 1.1921e-07, 9.0417e-01, 4.6380e-03,
1.1921e-07],
[ 1.1921e-07, 8.1971e-07, 1.1921e-07, 4.9934e-02, 1.1921e-07,
1.1921e-07, 1.1921e-07, 1.1921e-07, 1.0439e-01, 1.1921e-07,
1.1921e-07, 9.4799e-03, 8.3286e-01, 1.1921e-07, 2.0750e-04,
1.1921e-07, 3.1230e-03, 1.1921e-07, 7.0881e-06, 1.1921e-07,
1.4453e-06]])
File "main_tPMM.py", line 107, in guide
print('test.log_prob(sample):\n', test.log_prob(sample))
File "/anaconda3/lib/python3.6/site-packages/torch/distributions/independent.py", line 78, in log_prob
log_prob = self.base_dist.log_prob(value)
File "/anaconda3/lib/python3.6/site-packages/torch/distributions/dirichlet.py", line 73, in log_prob
self._validate_sample(value)
File "/anaconda3/lib/python3.6/site-packages/torch/distributions/distribution.py", line 221, in _validate_sample
raise ValueError('The value argument must be within the support')
ValueError: The value argument must be within the support
Trace Shapes:
Param Sites:
beta_q 2 21
Sample Sites:
beta dist | 2 21
value | 2 21
However, if I take those variables, beta_q and sample and run
test = dist.Dirichlet(beta_q).independent(1)
test.log_prob(sample)
directly in the terminal no error is thrown.
Python 3.6.5 from Anaconda
Pyro version 0.2.1+89fd6d3a
PyTorch version 0.4.0
Can you try using the latest version of Pyro's dev branch and the most recent PyTorch 1.0 nightly build and see if the issue is still there?
This may also be due to stochastic learning difficulties (e.g. too high a learning rate can take you to bad places).
Hi @kgero, I believe Pyro turns on validation, whereas the command-line version does not. Since this looks like an ignorable precision issue, you can disable validation at that one sample site by setting validate_args=False:
beta = pyro.sample('beta',
- dist.Dirichlet(torch.ones(K, V))
+ dist.Dirichlet(torch.ones(K, V), validate_args=False)
.independent(1))
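To make the effect of that flag concrete, here is a minimal sketch in plain PyTorch (no Pyro). The concentration and the off-simplex value are made-up illustration data, not taken from the model above; the value's entries sum to 1.001, so it is well outside the sum-to-1 tolerance.

```python
import torch
from torch.distributions import Dirichlet

concentration = torch.ones(3)

# Illustrative value (an assumption, not from the thread): sums to 1.001.
off_simplex = torch.tensor([0.5, 0.3, 0.201])

# With validation on, log_prob checks the support and raises ValueError.
strict = Dirichlet(concentration, validate_args=True)
try:
    strict.log_prob(off_simplex)
    raised = False
except ValueError:
    raised = True

# With validation off, the same call skips the check and returns a number.
lenient = Dirichlet(concentration, validate_args=False)
log_p = lenient.log_prob(off_simplex)

print("strict raised:", raised)
print("lenient log_prob:", log_p.item())
```

Note that turning validation off only skips the check; it does not change the sampled value, so it is best reserved for cases where the deviation is known to be harmless rounding.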
Updating the version did not fix the problem. My version is now
torch.__version__ 1.0.0.dev20181206
pyro.__version__ 0.2.1+9e76df6c
However, introducing validate_args=False did. It still seems strange that the sample should fail the sum-to-1 check. Do you want me to close the issue?
I'm ok leaving the issue open for now. If other users see the same error, we may loosen the validation upstream in PyTorch.
Hello! I ran into the same problem just now. After I commented out the line pyro.enable_validation(True), my program works.
We observed some numerical stability issues in NumPyro, including some with the StickBreaking transform. I will address these problems upstream in PyTorch.
I think the issue is due to a precision error: abs((x / x.sum(-1, keepdim=True)).sum(-1) - 1) > 1e-6, which can happen in float32.
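One possible reading of the sample printed earlier in this thread, offered as an assumption rather than a confirmed diagnosis: many of its entries are exactly 1.1921e-07, which is float32 machine epsilon (2**-23). If those entries were raised to that floor after normalization, each one could add up to epsilon to the row sum. The arithmetic below (plain Python, with hypothetical helper names) shows how few such entries are needed to exceed a 1e-6 tolerance.

```python
# float32 machine epsilon; prints as 1.1921e-07 at the precision used in the log
FLOAT32_EPS = 2.0 ** -23
TOLERANCE = 1e-6  # sum-to-1 tolerance mentioned in this thread


def excess_from_floor(n_floored, floor=FLOAT32_EPS):
    """Upper bound on the mass added when n_floored entries are raised to `floor`."""
    return n_floored * floor


# The first printed row of `sample` has 13 entries equal to 1.1921e-07.
excess_13 = excess_from_floor(13)

# Smallest number of floored entries whose combined mass exceeds the tolerance.
min_count = next(n for n in range(1, 100) if excess_from_floor(n) > TOLERANCE)

print("excess from 13 floored entries:", excess_13)
print("entries needed to exceed tolerance:", min_count)
```

Under this assumption, a row needs only a handful of near-zero entries before the check can fail, which would be consistent with the error appearing only once the vocabulary grows.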
However, the issue seems to be fixed in the latest release of PyTorch. More context can be found at this issue and this PR. I ran the following code many times (on both GPU and CPU) and all the samples pass the validation check.
import torch

torch.distributions.constraints.simplex.check(
    torch._sample_dirichlet(torch.randn(100000, 100).abs())).all()
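To get a feel for the size of the rounding error involved, the following sketch normalizes the same random weights in float32 and in float64 and reports the largest deviation of any row sum from 1. The data and the helper name row_sum_error are hypothetical, chosen only for illustration; the exact float32 figure will vary with the seed and platform.

```python
import torch
from torch.distributions import constraints

torch.manual_seed(0)

# Hypothetical stand-in data: 1000 rows of 100 strictly positive weights.
raw = torch.rand(1000, 100) + 1e-3


def row_sum_error(x):
    """Largest absolute deviation of any row sum from 1."""
    return (x.sum(-1) - 1).abs().max().item()


# Normalize each row in float32, then again in float64.
p32 = raw / raw.sum(-1, keepdim=True)
raw64 = raw.double()
p64 = raw64 / raw64.sum(-1, keepdim=True)

err32 = row_sum_error(p32)
err64 = row_sum_error(p64)

# The float64 rows pass the same simplex check used in the snippet above.
ok64 = bool(constraints.simplex.check(p64).all())

print("float32 max row-sum error:", err32)
print("float64 max row-sum error:", err64)
print("float64 rows on simplex:", ok64)
```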
Please reopen the issue if this is still a problem.