This issue was already discovered by @fehiepsi and discussed in https://github.com/uber/pyro/pull/1678, but I was recently bitten by it again. I am opening this to track the issue and discuss possible solutions or fixes, if there is indeed something that can be fixed.
For example, here is a model that works very well with the default tensor type set to torch.DoubleTensor (as it is currently set below), but is slow with torch.FloatTensor. The slowness comes from the adaptation phase, where we keep decreasing the step size to a very small value. This results in very slow mixing of the chains, as can be seen from the diagnostics.
```python
import pymc3 as pm
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS
import theano.tensor as tt
import torch

torch.set_printoptions(precision=10)
# Sampling works well with DoubleTensor; switching to 'torch.FloatTensor' reproduces the slowdown.
torch.set_default_tensor_type('torch.DoubleTensor')


def pm_model(data):
    with pm.Model() as model:
        p = pm.Beta('p', 1., 1.)
        p_print = tt.printing.Print('p')(p)
        pm.Binomial('obs', data['n'], p_print, observed=data['x'])
    return model


def pyro_model(data):
    p = pyro.sample('p', dist.Beta(1., 1.))
    pyro.sample('obs', dist.Binomial(data['n'], p), obs=data['x'])


def get_samples_pymc(data, num_samples=200, warmup_steps=200):
    data = {k: v.numpy() for k, v in data.items()}
    with pm_model(data):
        trace = pm.sample(draws=num_samples, tune=warmup_steps, chains=2)
    return trace


def get_samples_pyro(data, num_samples=200, warmup_steps=200):
    nuts_kernel = NUTS(pyro_model,
                       adapt_step_size=True,
                       adapt_mass_matrix=True,
                       jit_compile=True,
                       full_mass=True)
    mcmc_run = MCMC(nuts_kernel,
                    num_samples=num_samples,
                    warmup_steps=warmup_steps,
                    num_chains=2).run(data)
    print(mcmc_run.marginal(sites=['p']).diagnostics())


data = {'n': torch.tensor(5000000.), 'x': torch.tensor(3849.)}
# get_samples_pymc(data)
get_samples_pyro(data)
```
The likely reason is that our numerical precision with float tensors is too low: the acceptance rate with FloatTensor is low, so we end up massively decreasing the step size, which in turn leads to slow sampling and mixing. This could point to:
It would be nice if there were some way for us to at least detect this and throw a warning, to help users debug such cases.
~Additionally, I found that in some larger models, we were able to make much faster progress by decreasing the tree depth during adaptation to, say, 6, without affecting the inference results. We could consider doing so by default (i.e. use a different tree depth during adaptation, which the user can override).~ [This is a separate issue worth discussing on its own, and it is not relevant to this model.]
The maximum tree_depth that PyMC3 reaches for this model is 2. You can pass max_tree_depth=2 to Pyro's NUTS kernel to make it fast!
I think the reason Pyro NUTS builds such deep trees is that it cannot meet the turning condition, due to the extremely small curvature of the posterior near the MAP point (except at the MAP point itself). I guess precision plays an important role here.
As for using a smaller tree_depth during adaptation, I think it is unnecessary. If the sampler learns a good mass matrix during adaptation, it will build smaller trees after adaptation. So it is enough to reduce max_tree_depth for the whole sampling run.
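To put numbers on why max_tree_depth matters for speed: each doubling of a NUTS tree doubles the trajectory, so the worst-case cost per iteration grows exponentially with the depth limit. A minimal sketch, assuming the standard doubling scheme in which a tree of depth d spans at most 2**d - 1 leapfrog steps (one potential-energy gradient each); Pyro's exact bookkeeping may differ by one. The 200 + 200 iterations and 2 chains mirror the example script, and 10 is the default max_tree_depth in current Pyro.

```python
def max_leapfrog_steps(max_tree_depth: int) -> int:
    """Upper bound on leapfrog steps (gradient evaluations) in one NUTS iteration.

    Assumes the standard doubling scheme: depth d covers at most 2**d - 1 steps.
    """
    return 2 ** max_tree_depth - 1


def worst_case_gradient_evals(max_tree_depth: int, iterations: int, chains: int = 1) -> int:
    """Total gradient evaluations if every tree saturates at max_tree_depth."""
    return max_leapfrog_steps(max_tree_depth) * iterations * chains


if __name__ == "__main__":
    iterations = 200 + 200  # warmup_steps + num_samples in the example script
    for depth in (2, 5, 10):
        print(depth, max_leapfrog_steps(depth), worst_case_gradient_evals(depth, iterations, chains=2))
```

Saturating at depth 10 costs up to 1023 gradients per iteration versus 3 at depth 2, which is why capping the depth hides the problem rather than fixing it.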
I stopped benchmarking against PyMC3 a while ago. You can verify that PyMC3 gives a completely wrong answer for this model with

```python
data["x"] = torch.tensor([3849., 384., 422., 3232., 1123., 2324.])
```

(I picked these values at random), while with max_tree_depth=2, Pyro NUTS gives a good answer with FloatTensor and the correct answer with DoubleTensor.
> Max of tree_depth in PyMC3 of this model is 2. You can set max_tree_depth=2 to Pyro MCMC to make it fast!
It would be preferable if our default options handled simple models without expecting the user to fine-tune hyperparameters. This example is isolated from a larger model, where it can get tricky to debug why sampling is slow. Also, while this wasn't intended to be a comparison with PyMC3 (this issue is relevant on its own), PyMC3 isn't setting max_tree_depth to 2. The question is: why do we need to explore a much deeper tree if 2 should suffice? It could be a numerical instability issue in torch distributions or transforms, but either way it makes Pyro appear much slower.
> If during adaptation, the model learns a good mass matrix, then it will have smaller tree_depth after adaptation.
While I agree that the tree_depth can be expected to be lower during the sampling stage in most models, why artificially constrain it? On the other hand, decreasing the tree depth during adaptation (to something small like 4-5) would allow for faster adaptation. This needs further research, though; I am not sure whether it might result in us learning a bad step size or mass matrix in certain cases.
> You can verify that PyMC3 gives a completely wrong answer for this model with
What is data[x]? Are you sampling from the posterior predictive? I do not notice anything wrong with the sampling.
> What is data[x] - are you sampling from the posterior predictive? I do not notice anything wrong with the sampling.
I mean:

```python
data = {'n': torch.tensor(5000000.), 'x': torch.tensor([3849., 384., 422., 3232., 1123., 2324.])}
```
I don't think setting max_tree_depth counts as fine-tuning. It is much easier than reducing warmup_steps, and we only do it to make things faster.
> the tree_depth can be expected to be lower even for the sampling stage in most models, why artificially constrain it?
I don't have my own reason to change it. It is only useful to me if we need speed.
> why do we need to explore a much deeper tree if 2 should suffice?
How can we know if 2 is enough? At least for this simple model with FloatTensor, it is not enough, because the trees don't turn.
I tried reducing max_tree_depth to 5 during the adaptation phase, but after adaptation it is still slow (the trees don't turn).
Fixing the precision issue (if possible) could be a solution: a small change in p should not lead to a large change in log_prob, otherwise we cannot get a reasonable accept_prob. IMO, working in this direction is a good way to improve NUTS.
One difficulty is that we don't have a comparable library to benchmark against: PyMC3 is unreliable (to me), and Stan works in double precision.
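To see why a jumpy log_prob ruins accept_prob: the acceptance statistic that step-size adaptation tracks is min(1, exp(-dH)), where dH is the change in the Hamiltonian (potential energy, i.e. -log_prob, plus kinetic energy) along the trajectory. A minimal sketch, assuming an otherwise perfect integrator so that dH comes only from rounding error in log_prob; the jump size of 4 matches the steps between -1020 and -1024 in the FloatTensor output later in this thread.

```python
import math


def acceptance_stat(delta_h: float) -> float:
    """Metropolis acceptance statistic min(1, exp(-dH)) for an energy error dH."""
    return min(1.0, math.exp(-delta_h))


if __name__ == "__main__":
    # Energy errors caused purely by log_prob jumping by 0 or +-4 between nearby points
    for delta_h in (0.0, -4.0, 4.0, 8.0):
        print(delta_h, acceptance_stat(delta_h))
```

A single rounding jump of 4 in the unfavorable direction pushes the statistic to about 0.018, far below a target such as 0.8 (Pyro's default target_accept_prob).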
> data = {'n': torch.tensor(5000000.), 'x': torch.tensor([3849., 384., 422., 3232., 1123., 2324.])}
I still don't see this issue; sampling seems to work very well. If you further increase the number of samples and warmup steps to 300, the convergence diagnostics look good. Did you add shape=6 to pm.Beta?
Sorry, I might have caused some confusion, because there are at least two independent issues here. Let me try to clarify (sticking to the simple univariate example I posted above, to focus on a single issue):
torch.FloatTensor, but what is concerning is that it manifests with the default settings on such a simple model. Note that this isn't just a question of slowness: the resulting chains haven't mixed, as you can see from our convergence diagnostics. Even if you increase the number of samples to 300 in the above example, r_hat is 75. This is expected, since we are taking a far smaller step size than needed. Relatedly,
> How can we know if 2 is enough?
2 is enough if we learn a reasonable step size and mass matrix, which we don't. I am not saying that we should resort to max_tree_depth=2, just that if there were no precision issue, we wouldn't end up with such a small step size and need to go beyond depth 2. We don't need PyMC3 to see that: for the simple univariate model I pasted above, the tree depth doesn't need to go above 2 (it reaches 3 a few times during adaptation) if you use DoubleTensor.
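The link between a shrinking step size and a growing tree depth can be sketched with a rough rule of thumb. This is an assumption for illustration, not how NUTS actually decides when to stop: if a U-turn requires a trajectory of length T in unconstrained space, NUTS needs about T / step_size leapfrog steps, i.e. a depth of roughly ceil(log2(T / step_size + 1)). The trajectory length below is a hypothetical value chosen only to show the scaling.

```python
import math


def approx_tree_depth(trajectory_length: float, step_size: float) -> int:
    """Rough depth needed to cover trajectory_length with leapfrog steps of step_size.

    Rule of thumb only: assumes the tree doubles until 2**d - 1 >= steps needed.
    """
    steps_needed = math.ceil(trajectory_length / step_size)
    return math.ceil(math.log2(steps_needed + 1))


if __name__ == "__main__":
    trajectory_length = 0.05  # hypothetical U-turn length in unconstrained space
    for step_size in (0.02, 2e-4, 2e-6):
        print(step_size, approx_tree_depth(trajectory_length, step_size))
```

Under this rule, every 100x reduction in step size adds about log2(100), roughly 6.6, levels to the tree, so a depth-2 problem becomes a depth-8 or depth-15 one.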
> Did you add shape=6 to pm.Beta?
No, I didn't. I just used the same PyMC3 model. It is weird that Pyro NUTS needs only 20 warmup steps to reach the typical set, while PyMC3 requires 300. I guess trees in PyMC3 turn so quickly that it stops exploring further (or I used PyMC3 the wrong way).
> 2 is enough if we learn a reasonable step size and mass matrix, which we don't.
I expect the samples to be correlated because the posterior is nearly a Delta distribution (I hope I am wrong). Because of that, a good "step_size" or "inverse_mass_matrix" (which represents the variance of the p samples) should be extremely small once we reach the typical set, which is essentially the MAP point. Let's forget about inverse_mass_matrix (set adapt_mass_matrix to False), because we only have a single one-dimensional latent variable. I did some tests for step_size: warmup=20 is enough to reach the typical set. After that, I set step_size to a "guess-to-be-good" value. Here is what I observed:
I think the problem is not learning a good step_size or inverse_mass_matrix, but getting good precision for pe (the potential energy) or its gradient. This would have two benefits: accept_prob will be reasonable, and velocities will be reasonable (so we can meet the turning condition, which is the core of NUTS).
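For a sense of scale of the near-Delta posterior mentioned above: with the Beta(1, 1) prior, conjugacy gives the posterior Beta(1 + x, 1 + n - x). A short sketch using the standard Beta mean and variance formulas (the conversion to the logit scale is a delta-method approximation) shows that the posterior standard deviation is tiny in p, but about 0.016 in the unconstrained logit space in which HMC moves, since the unit-interval transform is a sigmoid.

```python
import math


def beta_binomial_posterior(n: float, x: float, a: float = 1.0, b: float = 1.0):
    """Posterior Beta(a + x, b + n - x) for a Binomial(n, p) likelihood and a Beta(a, b) prior."""
    return a + x, b + n - x


def beta_mean_sd(alpha: float, beta: float):
    """Mean and standard deviation of a Beta(alpha, beta) distribution."""
    total = alpha + beta
    mean = alpha / total
    var = alpha * beta / (total ** 2 * (total + 1))
    return mean, math.sqrt(var)


if __name__ == "__main__":
    alpha, beta = beta_binomial_posterior(n=5_000_000, x=3849)
    mean, sd_p = beta_mean_sd(alpha, beta)
    # Delta method: d logit(p) / dp = 1 / (p (1 - p))
    sd_logit = sd_p / (mean * (1 - mean))
    print(f"mean p = {mean:.7f}, sd(p) = {sd_p:.2e}, sd(logit p) ~ {sd_logit:.4f}")
```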
> I think that the problem is not about learning good step_size or inverse_mass_matrix, but about how to get good precision for pe (or its grad).
That is precisely the problem, and the reason I opened this issue. Just to emphasize the causal link: it is due to bad precision that we end up using a smaller step size than this model needs. And it is due to the very small step size that we need to explore the tree to a much greater depth, resulting in slow sampling and mixing. I didn't mean to suggest that there is an issue with our adaptation scheme at all. As I mentioned earlier, it is most likely a numerical instability issue in the distribution's log_prob, the gradient computation, or the transform code.
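The "low acceptance leads to a smaller step size" link comes from step-size adaptation. Below is a minimal sketch of the dual-averaging scheme from Hoffman and Gelman's NUTS paper, which NUTS implementations commonly use. I have not checked Pyro's exact constants; the ones here are the paper's suggested defaults, and the 0.8 target mirrors Pyro's default target_accept_prob. Under the simplifying assumption that the acceptance statistic stays low regardless of step size, the step size keeps shrinking.

```python
import math


class DualAveragingStepSize:
    """Dual-averaging step-size adaptation in the style of Hoffman & Gelman (2014)."""

    def __init__(self, initial_step_size, target_accept=0.8, gamma=0.05, t0=10.0, kappa=0.75):
        self.mu = math.log(10.0 * initial_step_size)  # point that log(step size) is shrunk toward
        self.target = target_accept
        self.gamma, self.t0, self.kappa = gamma, t0, kappa
        self.t = 0
        self.h_bar = 0.0  # running average of (target - accept_stat)
        self.log_step = math.log(initial_step_size)
        self.log_step_avg = 0.0  # averaged iterate, used as the step size after warmup

    def update(self, accept_stat):
        """Feed one acceptance statistic and return the next step size."""
        self.t += 1
        t = self.t
        eta = 1.0 / (t + self.t0)
        self.h_bar = (1.0 - eta) * self.h_bar + eta * (self.target - accept_stat)
        self.log_step = self.mu - math.sqrt(t) / self.gamma * self.h_bar
        weight = t ** (-self.kappa)
        self.log_step_avg = weight * self.log_step + (1.0 - weight) * self.log_step_avg
        return math.exp(self.log_step)


if __name__ == "__main__":
    adapter = DualAveragingStepSize(initial_step_size=0.1)
    for _ in range(50):
        step = adapter.update(accept_stat=0.02)  # persistently low acceptance
    print(step, math.exp(adapter.log_step_avg))
```

With a persistently low statistic, the running average stays positive and log(step size) falls like -sqrt(t), so the step size collapses; an acceptance statistic above the target pushes it up instead.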
EDIT: Modified the title to reflect the issue better.
Totally agree with your causal link.
@neerajprad What is your current guess as to the location of the precision issue (Beta value, Beta grad, transform code)? The transform code should be simply torch.exp() or torch.sigmoid(), and I doubt those have numerical issues. There are some other numerical issues with Beta.sample() that might surface during initialization in HMC.
You are right that the transform code is likely not at play here. ~Any numerical issues in Beta.sample() shouldn't affect us since we don't exercise the .sample() method, just the .log_prob() method when computing the potential energy.~ (EDIT: This could impact the initial trace setting, but it is not an issue with this example, and we observe a low acceptance rate even after we reach the typical set, as @fehiepsi noted above.) My first guess would be the gradient computation of the Beta distribution (maybe the gradient of torch.lgamma, since everything else seems straightforward); small discrepancies there could add up and lead to a diverging Hamiltonian trajectory. Will investigate further and report back.
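A quick sanity check on the lgamma hypothesis is how coarse float32 is at the magnitudes involved. The Binomial likelihood's log-factorial term lgamma(n + 1) is about 7.2e7 for n = 5000000, and a term like n * 7.2 (n times a logit of magnitude 7.2) is about 3.6e7. This standard-library sketch emulates float32 storage by round-tripping through struct and computes the gap between adjacent float32 values at those magnitudes; which terms PyTorch actually sums is an assumption here.

```python
import math
import struct


def to_float32(x: float) -> float:
    """Round a Python float (IEEE double) to the nearest IEEE single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def float32_spacing(x: float) -> float:
    """Gap between |x| (as float32) and the next larger float32 value (normal range only)."""
    x = abs(to_float32(x))
    exponent = math.frexp(x)[1] - 1  # x = m * 2**exponent with 1 <= m < 2
    return 2.0 ** (exponent - 23)  # float32 has a 23-bit fraction


if __name__ == "__main__":
    n = 5_000_000
    print("lgamma(n + 1):", math.lgamma(n + 1), float32_spacing(math.lgamma(n + 1)))
    print("n * 7.2:      ", n * 7.2, float32_spacing(n * 7.2))
    print("about -7:     ", -6.92, float32_spacing(-6.92))
```

Terms of that size can only be stored to within a few units, while the log-likelihood values HMC compares near the typical set are around -7.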
I guess the following script shows the precision issue (or maybe this level of error is acceptable).
```python
import torch
import torch.distributions as tdist


def f(x, step_size, dtype=torch.float):
    # x is an unconstrained value; transform_to(unit_interval) maps it into (0, 1)
    x = torch.tensor(x + step_size, requires_grad=True, dtype=dtype)
    x_constrained = tdist.transform_to(tdist.constraints.unit_interval)(x)
    log_prob = tdist.Binomial(5000000, x_constrained).log_prob(torch.tensor(3849., dtype=dtype))
    # Return the log-likelihood and its gradient w.r.t. the unconstrained x
    return log_prob.detach().item(), torch.autograd.grad(log_prob, x)[0].item()


x = -8
print("=== FloatTensor")
for i in range(4, 9):
    step_size = 10 ** (-i)
    print(step_size, f(x, step_size))
print("=== DoubleTensor")
for i in range(4, 9):
    step_size = 10 ** (-i)
    print(step_size, f(x, step_size, torch.double))
```
Output (format: step_size (log_prob, grad)):

```
=== FloatTensor
0.0001 (-1020.0, 2172.0)
1e-05 (-1024.0, 2172.0)
1e-06 (-1020.0, 2172.5)
1e-07 (-1024.0, 2172.0)
1e-08 (-1024.0, 2172.0)
=== DoubleTensor
0.0001 (-1031.3999517038465, 2172.081720456481)
1e-05 (-1031.5954458490014, 2172.2325857002284)
1e-06 (-1031.6149960160255, 2172.247671478428)
1e-07 (-1031.6169510409236, 2172.2491800487046)
1e-08 (-1031.6171465367079, 2172.2493309061974)
```
When we move near the typical set, x = -7.2 (close to the inverse transform of 3849/5000000, which is about -7.17), we get:
```
=== FloatTensor
0.0001 (4.0, 118.5)
1e-05 (4.0, 119.00000762939453)
1e-06 (8.0, 118.50000762939453)
1e-07 (4.0, 119.00000762939453)
1e-08 (4.0, 119.00000762939453)
=== DoubleTensor
0.0001 (-6.909695975482464, 118.48307606950402)
1e-05 (-6.920374557375908, 118.81855701655151)
1e-06 (-6.9214440658688545, 118.85210345312953)
1e-07 (-6.921551033854485, 118.85545808076859)
1e-08 (-6.921561725437641, 118.85579354315996)
```
Note that 3849/5000000 = 0.0007698, which is not a particularly small probability, even in single precision. Four things we can observe:
Thanks for digging into this, @fehiepsi. I'll play around with your script a bit to see what's happening here.
I think we can fix the precision issue by using probs instead of logits when computing the log_prob of the Binomial distribution:
```python
import torch
import pyro.distributions as dist


class PyroBinomial(dist.Binomial):
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        log_factorial_n = torch.lgamma(self.total_count + 1)
        log_factorial_k = torch.lgamma(value + 1)
        log_factorial_nmk = torch.lgamma(self.total_count - value + 1)
        # log C(n, k) + k * log(p) + (n - k) * log(1 - p), computed directly from probs
        return (log_factorial_n - log_factorial_k - log_factorial_nmk +
                value * self.probs.log() + (self.total_count - value) * torch.log1p(-self.probs))
```
This version removes the discontinuity, and HMC runs well with it. However, log_prob can still come out positive (as long as log_prob is continuous, this does not matter to HMC).
We previously used this version in Pyro. I don't know why PyTorch upstream computes log_prob from logits. I guess the PyTorch version is more stable when Binomial is initialized with logits instead of probs. cc @alicanb
> I guess the PyTorch version is more stable when Binomial is initialized with logits instead of probs.
That was the idea, i.e. not having to use probs when the Binomial was defined using logits. Even in that version, the discontinuity can be solved simply by grouping the total_count terms as self.total_count * (max_val - torch.log1p((self.logits + 2 * max_val).exp())) in the .log_prob method.
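The identity behind that grouped term can be checked numerically in double precision with the standard library. This sketch assumes, as in the expression above, max_val = max(-logits, 0), and verifies that max_val - log1p(exp(logits + 2 * max_val)) equals log(1 - sigmoid(logits)). Grouping therefore computes log(1 - p), a small number, before multiplying by total_count, instead of subtracting two products of size total_count * max_val.

```python
import math


def log1m_prob_from_logits(logit: float) -> float:
    """log(1 - sigmoid(logit)) via the max_val form used in the expression above."""
    max_val = max(-logit, 0.0)
    return max_val - math.log1p(math.exp(logit + 2.0 * max_val))


def sigmoid(logit: float) -> float:
    return 1.0 / (1.0 + math.exp(-logit))


if __name__ == "__main__":
    for logit in (-30.0, -7.2, -1.0, 0.0, 2.5, 10.0):
        print(logit, log1m_prob_from_logits(logit), math.log1p(-sigmoid(logit)))
```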
I think what might be happening is that converting from probs to logits loses precision, so the logits version gives bad results when the Binomial is initialized with probs. If that is the case, it might affect other distributions too. In any case, I am really happy to see this getting tracked down, and to know that there is a simple fix like the one you mention above!
@fehiepsi, the issue is actually with the log_prob term, not the conversion into logits. I will send a fix to torch.distributions shortly.
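For anyone who wants to reproduce the discontinuity without PyTorch, here is a standard-library sketch that emulates float32 by rounding every intermediate result with struct. It is an emulation, not PyTorch's kernels (real float32 lgamma, exp and log1p may round differently). The logits-based function is my reading of the upstream formula implied by the max_val expression quoted above: it adds total_count * max_val and subtracts total_count * log1p(...) separately, each about 3.6e7. It is compared against the probs-based formula from the PyroBinomial class above.

```python
import math
import struct


def f32(x: float) -> float:
    """Round a double to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def exact(x: float) -> float:
    """No rounding: plain double precision, used as a reference."""
    return x


def log_binom_coeff(n, k, r):
    # lgamma(n + 1) - lgamma(k + 1) - lgamma(n - k + 1), rounding each op with r
    lf_n = r(math.lgamma(n + 1))
    lf_k = r(math.lgamma(k + 1))
    lf_nmk = r(math.lgamma(n - k + 1))
    return r(r(lf_n - lf_k) - lf_nmk)


def log_prob_logits(n, k, logit, r=f32):
    """Logits-based form: the large n * max_val and n * softplus terms are handled separately."""
    logit = r(logit)
    max_val = r(max(-logit, 0.0))
    softplus = r(math.log1p(r(math.exp(r(logit + r(2.0 * max_val))))))
    out = r(log_binom_coeff(n, k, r) + r(k * logit))
    out = r(out + r(n * max_val))
    return r(out - r(n * softplus))


def log_prob_probs(n, k, logit, r=f32):
    """Probs-based form from the PyroBinomial class above."""
    p = r(1.0 / (1.0 + math.exp(-r(logit))))
    out = r(log_binom_coeff(n, k, r) + r(k * r(math.log(p))))
    return r(out + r((n - k) * r(math.log1p(-p))))


if __name__ == "__main__":
    n, k = 5_000_000.0, 3849.0
    for logit in (-7.24, -7.23, -7.22, -7.21, -7.20):
        print(logit,
              log_prob_logits(n, k, logit),        # float32 emulation, logits form
              log_prob_probs(n, k, logit),         # float32 emulation, probs form
              log_prob_probs(n, k, logit, exact))  # double-precision reference
```

In this emulation every logits-form value is a multiple of 4, because the running sum passes through numbers near 3.6e7 where float32 values are 4 apart. The probs-form values change smoothly with the logit and differ from the double reference by a nearly constant offset (from rounding the log-factorial terms), which is consistent with the remaining positive log_prob noted above.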
Nice! :dancing_men:
@fehiepsi, thank you for discovering this and helping debug it!
Closing, as the bug is fixed in PyTorch master.