For the equivalent models below, the PyMC3 model finishes within a second, whereas the Pyro model makes extremely slow progress. I came across this perf issue while working on a larger model and isolated it to the example below. I am not sure whether this is a perf bug or whether PyMC3 uses some adaptation tricks, but it is worth looking into and learning from. The cause of the slowdown is the very small step size in Pyro (~10^-9), whereas in PyMC3 the step size stays on the order of 10^-3 throughout. This is also not a case of bad initialization: the same behavior can be observed for different random seeds.
Also refer to #1470.
import pymc3 as pm
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS
import torch


def pm_model(data):
    # Gamma-Poisson model with half-Cauchy hyperpriors (PyMC3 version).
    with pm.Model() as model:
        alpha = pm.HalfCauchy('alpha', beta=2)
        beta = pm.HalfCauchy('beta', beta=2)
        l = pm.Gamma('l', alpha, beta)
        pm.Poisson('obs', l, observed=data)
    return model


def pyro_model(data):
    # Same model written with Pyro primitives.
    alpha = pyro.sample('alpha', dist.HalfCauchy(2.))
    beta = pyro.sample('beta', dist.HalfCauchy(2.))
    l = pyro.sample('l', dist.Gamma(alpha, beta))
    pyro.sample('obs', dist.Poisson(l), obs=data)


def get_samples_pymc(data, num_samples=200, warmup_steps=200):
    data = data.numpy()
    with pm_model(data):
        trace = pm.sample(draws=num_samples, tune=warmup_steps, chains=1)
    return trace


def get_samples_pyro(data, num_samples=200, warmup_steps=200):
    nuts_kernel = NUTS(pyro_model,
                       adapt_step_size=True,
                       adapt_mass_matrix=True,
                       jit_compile=True,
                       full_mass=True)
    mcmc_run = MCMC(nuts_kernel,
                    num_samples=num_samples,
                    warmup_steps=warmup_steps,
                    num_chains=1).run(data)
    return mcmc_run


data = torch.tensor(4805497.)
get_samples_pymc(data)  # finishes in a second.
get_samples_pyro(data)  # orders of magnitude slower.
Here are some of my thoughts on this:
I think that this model is not a good setup for the data, so the slowness is acceptable in my opinion.
This is from a larger model from a user (originally written in PyMC3), and the above is just a minimal example that is representative of the issue. I do agree that this is not an ideal model setup, but I also think that we can use this simple case to make NUTS more robust (and awesome), both because bigger, more complex examples might exhibit similar patterns, and for a better user experience. Consider for example a model that worked well on smaller observations but suddenly became slow because the user got some new data that was much larger. In such cases, we give no warning and effectively crawl to a halt. Both PyMC3 and Stan (which uses the same windowed adaptation scheme) are extremely fast on this, though Stan throws a few warnings about divergent transitions.
I guess we can fix this problem by setting the initial inverse mass matrix to a small value, or by turning off adapt_step_size and setting step_size = 1e-4, for example.
Ideally, we would like inference to be as automatic as possible, or at least to raise warnings so that users know what to do when the data or model seems aberrant. Setting the step size upfront is not an option because the actual model is much larger, and a fixed value may not work for the rest of the model. Likewise, it would be much better for us to start with a reasonable mass matrix (somehow), rather than make the user guess and provide one.
I am also unsure whether this is just an issue with mass matrix initialization, since the step size remains small after adaptation too. I am still looking into this and will update this issue if I find something interesting.
It is surprising to me that Stan is fast on this. Given only one data point, the latent l will converge to a value with very small variance. Hence alpha / beta also converges to a value, which implies that alpha and beta are highly correlated.
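To put a number on the "very small variance" claim: for fixed alpha and beta, the Gamma prior is conjugate to the Poisson likelihood, so with a single observation y the conditional posterior of l is Gamma(alpha + y, beta + 1) in the shape/rate parameterization. Here is a rough sketch of that calculation. The values alpha = beta = 1 are purely illustrative assumptions, not values taken from an actual run.

```python
import math


def conditional_posterior_l(alpha, beta, y):
    """Mean and sd of l | alpha, beta, y for l ~ Gamma(alpha, rate=beta), y ~ Poisson(l)."""
    shape = alpha + y   # conjugate update: the observed count is added to the shape
    rate = beta + 1.0   # a single observation adds 1 to the rate
    mean = shape / rate
    sd = math.sqrt(shape) / rate
    return mean, sd


y = 4805497.0
# alpha = beta = 1 are illustrative values only.
mean, sd = conditional_posterior_l(alpha=1.0, beta=1.0, y=y)
print("mean:", mean, "sd:", sd, "relative sd:", sd / mean)
```

The relative sd is 1 / sqrt(alpha + y), which is below 1e-3 for this observation no matter what moderate values alpha and beta take, so l is pinned down very tightly compared with the heavy-tailed hyperpriors.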
If the small step_size is a real problem, then I guess we can scale initial_mass_matrix so that step_size has a higher value at the end of the find_reasonable_step_size method. Theoretically, things should be the same, but I'll check whether it works. Maybe scaling step_size will help with precision issues or with dual_averaging.
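For reference while reasoning about the dual_averaging part, here is a minimal sketch of the textbook dual averaging update for the log step size, in the form given in the NUTS paper. The class name, the helper function, and the constants (target acceptance 0.8, gamma 0.05, t0 10, kappa 0.75) are assumptions for illustration; I have not checked that Pyro's implementation matches this term by term.

```python
import math


class DualAveraging:
    """Textbook dual averaging on log(step_size); an illustrative sketch only."""

    def __init__(self, initial_step_size, target_accept=0.8,
                 gamma=0.05, t0=10.0, kappa=0.75):
        self.mu = math.log(10.0 * initial_step_size)  # point the iterates shrink towards
        self.target_accept = target_accept
        self.gamma = gamma
        self.t0 = t0
        self.kappa = kappa
        self.t = 0
        self.h_bar = 0.0        # running average of (target - observed acceptance)
        self.log_eps_bar = 0.0  # averaged log step size, used after warmup

    def update(self, accept_prob):
        self.t += 1
        w = 1.0 / (self.t + self.t0)
        self.h_bar = (1.0 - w) * self.h_bar + w * (self.target_accept - accept_prob)
        log_eps = self.mu - math.sqrt(self.t) / self.gamma * self.h_bar
        eta = self.t ** (-self.kappa)
        self.log_eps_bar = eta * log_eps + (1.0 - eta) * self.log_eps_bar
        return math.exp(log_eps)  # step size to use for the next transition

    def final_step_size(self):
        return math.exp(self.log_eps_bar)


def adapted_step_size(accept_prob, num_steps=200, initial_step_size=1.0):
    # Feed a constant acceptance probability to see where the step size ends up.
    adapter = DualAveraging(initial_step_size)
    for _ in range(num_steps):
        adapter.update(accept_prob)
    return adapter.final_step_size()


print(adapted_step_size(0.0))  # every proposal rejected: step size collapses
print(adapted_step_size(1.0))  # every proposal accepted: step size grows
```

The point of the toy run is that if nearly every transition is rejected during a window, this scheme keeps shrinking the step size without any lower bound, which would be consistent with ending up at something like 10^-9.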
I just ran the program and observed that the small step_size arises during the initial adaptation window, not in find_reasonable_step_size. I'll look further into how Stan deals with this small step_size problem during the adaptation phase.
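For context, the step size heuristic itself can be sketched as below: take one leapfrog step, then keep doubling or halving the step size until the acceptance ratio crosses 0.5, as described in the NUTS paper. This is a toy version on a 1D Gaussian target with a unit mass matrix, not Pyro's internal code. The function names, the fixed momentum value, and the targets are assumptions chosen so the run is deterministic.

```python
import math


def leapfrog(q, r, step_size, grad_u):
    # One leapfrog step with unit mass.
    r = r - 0.5 * step_size * grad_u(q)
    q = q + step_size * r
    r = r - 0.5 * step_size * grad_u(q)
    return q, r


def find_reasonable_step_size(q, r, potential, grad_u, step_size=1.0):
    """Double or halve step_size until the one-step acceptance ratio crosses 0.5.

    r is the momentum draw, which would normally be sampled from N(0, 1).
    """
    energy = potential(q) + 0.5 * r * r

    def log_accept(eps):
        q_new, r_new = leapfrog(q, r, eps, grad_u)
        return energy - (potential(q_new) + 0.5 * r_new * r_new)

    # +1: acceptance is high, so try larger steps; -1: too low, so shrink.
    direction = 1 if log_accept(step_size) > math.log(0.5) else -1
    while direction * log_accept(step_size) > -direction * math.log(2.0):
        step_size *= 2.0 ** direction
    return step_size


def gaussian_target(sigma):
    # Potential energy and its gradient for a zero-mean Gaussian with sd sigma.
    def potential(q):
        return q * q / (2.0 * sigma * sigma)

    def grad_u(q):
        return q / (sigma * sigma)

    return potential, grad_u


wide = find_reasonable_step_size(1.0, 0.5, *gaussian_target(1.0))
narrow = find_reasonable_step_size(1e-3, 0.5, *gaussian_target(1e-3))
print(wide, narrow)
```

In this toy setting the heuristic lands on a step size comparable to the scale of the target (about three orders of magnitude smaller for the narrow one), which is small but nowhere near 10^-9. That fits the observation that the collapse happens later, during the adaptation window.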
The scale of the parameters is so vastly different that I'm not surprised that we need a small step size, but I would have expected that the mass matrix would help neutralize some of that and make it faster.
My guess is that we need many more slow adaptation windows for the mass matrix to be correctly adapted, but in the meantime we are stuck with a step size so low that it is impossible to make progress. Stan might have the same issue, but since it has a much faster runtime, it eventually learns a good mass matrix and is not stuck forever. PyMC3 probably gets around this problem by constantly adapting the mass matrix.
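As a sanity check on that expectation, here is a small sketch of how a well-adapted diagonal inverse mass matrix removes the scale problem for the leapfrog integrator. It uses a toy target with two independent Gaussian coordinates of very different scale, not the model from this issue. The function name, the scales, and the step size are illustrative assumptions.

```python
import math


def energy_error(scales, inv_mass, step_size, num_steps=10):
    """Largest |H - H0| along a leapfrog trajectory on an independent Gaussian target."""

    def grad(q):
        return [qi / (s * s) for qi, s in zip(q, scales)]

    def hamiltonian(q, r):
        pot = sum(0.5 * (qi / s) ** 2 for qi, s in zip(q, scales))
        kin = sum(0.5 * m * ri * ri for ri, m in zip(r, inv_mass))
        return pot + kin

    q = list(scales)                             # start one sd away from the mode
    r = [1.0 / math.sqrt(m) for m in inv_mass]   # a typical momentum under N(0, M)
    h0 = hamiltonian(q, r)
    worst = 0.0
    for _ in range(num_steps):
        r = [ri - 0.5 * step_size * gi for ri, gi in zip(r, grad(q))]
        q = [qi + step_size * m * ri for qi, m, ri in zip(q, inv_mass, r)]
        r = [ri - 0.5 * step_size * gi for ri, gi in zip(r, grad(q))]
        worst = max(worst, abs(hamiltonian(q, r) - h0))
    return worst


scales = [1e3, 1e-3]  # illustrative: six orders of magnitude apart
identity = energy_error(scales, inv_mass=[1.0, 1.0], step_size=0.5)
adapted = energy_error(scales, inv_mass=[s * s for s in scales], step_size=0.5)
print("identity mass:", identity, "adapted mass:", adapted)
```

With the identity mass matrix the trajectory blows up at step size 0.5, because the narrow coordinate forces a step size on the order of its own scale. With the inverse mass matrix set to the variances, the same step size keeps the energy error small. So once the mass matrix is actually learned, the small step size should no longer be needed.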