Pyro: Pyro NUTS inference slow when compared to the same PyMC3 model

Created on 17 Dec 2018 · 5 Comments · Source: pyro-ppl/pyro

For the two equivalent models below, the PyMC3 model finishes within a second, whereas the Pyro model makes extremely slow progress. I came across this perf issue while working on a larger model, and isolated it to the example below. I am not sure whether this is a perf bug or whether PyMC3 uses some adaptation tricks, but it is worth looking into and learning from. The cause of the slowdown is the very small step size in Pyro (~10^-9), whereas in PyMC3 the step size stays on the order of 10^-3 throughout. This is also not a case of bad initialization: the same can be observed for different random seeds.

Also refer to #1470.

import pymc3 as pm
import pyro

import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS

import torch


def pm_model(data):
    with pm.Model() as model:
        alpha = pm.HalfCauchy('alpha', beta=2)
        beta = pm.HalfCauchy('beta', beta=2)
        l = pm.Gamma('l', alpha, beta)
        pm.Poisson('obs', l, observed=data)
    return model


def pyro_model(data):
    alpha = pyro.sample('alpha', dist.HalfCauchy(2.))
    beta = pyro.sample('beta', dist.HalfCauchy(2.))
    l = pyro.sample('l', dist.Gamma(alpha, beta))
    pyro.sample('obs', dist.Poisson(l), obs=data)


def get_samples_pymc(data, num_samples=200, warmup_steps=200):
    # PyMC3 expects NumPy data, so convert the torch tensor first.
    data = data.numpy()
    with pm_model(data):
        trace = pm.sample(draws=num_samples, tune=warmup_steps, chains=1)
    return trace


def get_samples_pyro(data, num_samples=200, warmup_steps=200):
    nuts_kernel = NUTS(pyro_model,
                       adapt_step_size=True,
                       adapt_mass_matrix=True,
                       jit_compile=True,
                       full_mass=True)
    mcmc_run = MCMC(nuts_kernel,
                    num_samples=num_samples,
                    warmup_steps=warmup_steps,
                    num_chains=1).run(data)
    return mcmc_run


data = torch.tensor(4805497.)
get_samples_pymc(data)  # finishes in a second.
get_samples_pyro(data)  # orders of magnitude slower.

Labels: profiling, usability

All 5 comments

Here are some of my thoughts on this:

  • I think that this model is not a good setup for the data. The data here is pretty large for a Poisson distribution, so the initial energy will be pretty large. Hence it makes sense to have a very small initial step_size (given that we fix inverse_mass_matrix during the initial window adaptation).
  • I guess we can fix this problem by initializing the initial inverse mass matrix to a small value or turning off adapt_step_size, and set step_size = 1e-4 for example.
  • PyMC3 does not use windowed adaptation, so its inverse mass matrix somehow keeps step_size from shrinking to a very small value.
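
To give a rough sense of the scale involved, the sketch below evaluates only the Poisson term of the potential energy (the negative log-likelihood) for the single observation in this issue, written in terms of u = log(l). It ignores the priors and the Jacobian of the log transform, and the function names are made up for illustration; it is not how Pyro computes the energy internally.

```python
import math

y = 4805497.0  # the single observation used in the issue


def poisson_potential(log_l, y=y):
    """Negative Poisson log-likelihood as a function of u = log(l)."""
    l = math.exp(log_l)
    return l - y * log_l + math.lgamma(y + 1.0)


def poisson_potential_grad(log_l, y=y):
    """Derivative of the potential above with respect to u = log(l)."""
    return math.exp(log_l) - y


# Compare a few candidate values of l, ending at the likelihood mode l = y.
for l in (1.0, 1e3, 1e6, y):
    u = math.log(l)
    print(f"l={l:>10.4g}  potential={poisson_potential(u):.3e}  "
          f"grad={poisson_potential_grad(u):.3e}")
```

Away from l ≈ y, both the potential and its gradient are in the millions or more, which is consistent with the integrator needing a very small step early on.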

Anyway, the model setup is not good, so the slowness is acceptable in my opinion.

> I think that this model is not a good setup for the data.

This is from a larger model from a user (originally written in PyMC3), and the above is just a minimal example that is representative of the issue. I do agree that this is not an ideal model setup, but I also think that we can use this simple case to make NUTS more robust (and awesome), both because bigger, more complex examples might exhibit similar patterns, and for a better user experience. Consider, for example, a model that worked well on smaller observations but suddenly became slow because the user got some new data that was much larger. In such cases, we aren't providing any warning, but effectively crawling to a halt. Both PyMC3 and Stan (which uses the same windowed adaptation scheme) are extremely fast on this, though Stan throws a few warnings about divergent transitions.

> I guess we can fix this problem by initializing the initial inverse mass matrix to a small value or turning off adapt_step_size, and set step_size = 1e-4 for example.

Ideally, we would like inference to be as automatic as possible, or at least to raise warnings so that users know what to do when the data or model seems aberrant. Setting the step size upfront is not an option because the actual model is much larger, and a fixed value may not work for the rest of the model. Likewise, it would be much better for us to start with a reasonable mass matrix (somehow), rather than make the user guess and provide it.

I am also unsure whether this is just an issue with mass matrix initialization, since the step size remains small after the adaptation too. I am still looking into this, though, and will update this issue if I find something interesting.

It surprises me that Stan is fast on this. Given only one data point, the latent l will converge to a value with very small variance. Hence alpha / beta also converges to a value, which implies that alpha and beta are highly correlated.
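
The "very small variance" point can be checked with standard Gamma-Poisson conjugacy: conditional on alpha and beta, the posterior of l given one observation y is Gamma(alpha + y, rate = beta + 1), so its relative spread is 1 / sqrt(alpha + y). The sketch below is only the conditional for l, not the joint posterior that NUTS explores, and the helper name and the (alpha, beta) values are illustrative assumptions.

```python
import math

y = 4805497.0  # the single observation used in the issue


def conditional_posterior_l(alpha, beta, y=y):
    """Mean and sd of l | alpha, beta, y ~ Gamma(alpha + y, rate=beta + 1)."""
    shape, rate = alpha + y, beta + 1.0
    mean = shape / rate
    sd = math.sqrt(shape) / rate
    return mean, sd


# Arbitrary illustrative values of (alpha, beta); not taken from any run.
for alpha, beta in [(1.0, 1.0), (100.0, 1e-4), (5e6, 1.0)]:
    mean, sd = conditional_posterior_l(alpha, beta)
    print(f"alpha={alpha:g}  beta={beta:g}  mean={mean:.4g}  "
          f"relative_sd={sd / mean:.2e}")
```

For every choice of alpha and beta, the relative standard deviation of l stays below 1e-3, so l is pinned down far more tightly than the heavy-tailed priors on alpha and beta would suggest.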

If a small step_size is a real problem, then I guess we can scale initial_mass_matrix so that step_size ends up with a higher value at the end of the find_reasonable_step_size method. Theoretically, things should be the same, but I'll check if it works. Scaling step_size might also help with precision issues, dual averaging, etc.
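
The "theoretically the same" claim can be illustrated with a toy leapfrog integrator: multiplying the inverse mass by c^2, the momentum by 1/c, and the step size by 1/c leaves the visited positions unchanged. The sketch below uses a 1-D Gaussian potential with made-up numbers and a hand-written integrator; it is not Pyro's implementation.

```python
def leapfrog(q, p, step_size, inv_mass, grad_potential, num_steps):
    """Run leapfrog steps for a 1-D Hamiltonian and return the positions visited."""
    positions = []
    for _ in range(num_steps):
        p = p - 0.5 * step_size * grad_potential(q)  # half step for momentum
        q = q + step_size * inv_mass * p             # full step for position
        p = p - 0.5 * step_size * grad_potential(q)  # half step for momentum
        positions.append(q)
    return positions


def grad_potential(q, precision=1e6):
    """Gradient of a narrow Gaussian potential U(q) = 0.5 * precision * q**2."""
    return precision * q


c = 1e-3  # scaling factor: inverse mass shrinks by c**2, step size grows by 1/c

base = leapfrog(q=1e-3, p=1.0, step_size=1e-4, inv_mass=1.0,
                grad_potential=grad_potential, num_steps=50)
scaled = leapfrog(q=1e-3, p=1.0 / c, step_size=1e-4 / c, inv_mass=c ** 2,
                  grad_potential=grad_potential, num_steps=50)

max_diff = max(abs(a - b) for a, b in zip(base, scaled))
print(f"largest position difference between the two runs: {max_diff:.2e}")
```

Both runs trace the same trajectory up to floating-point error, even though the second uses a step size of about 0.1 instead of 1e-4. In exact arithmetic the rescaling changes nothing, so any benefit would have to come from numerical precision or from how the adaptation behaves.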

I just ran the program and observed that the small step_size arises during the initial adaptation window, not in find_reasonable_step_size. I'll further see how Stan deals with this small step_size problem during adapting phase.

> I'll further see how Stan deals with this small step_size problem during adapting phase.

The scale of the parameters is so vastly different, that I'm not surprised that we need a small step size, but I would have expected that the mass matrix would help neutralize some of that and make it faster.

My guess is that we need many more slow adaptation windows for the mass matrix to be correctly adapted, but in the meantime we are stuck with a step size so low that it is impossible to make progress. Stan might have the same issue, but since it has a much faster runtime, it eventually learns a good mass matrix and is not stuck forever. And PyMC3 probably gets around this problem by constantly adapting the mass matrix.
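
For context on what "adapting the mass matrix" involves, a diagonal mass matrix is typically estimated from the sample variance of the draws collected during an adaptation window, often with a streaming (Welford) update. The sketch below is a generic single-parameter version. The class name, the window length of 25, and the synthetic draws are assumptions for illustration and are not taken from Pyro, Stan, or PyMC3.

```python
import random


class WelfordVariance:
    """Streaming mean and variance estimate (Welford's algorithm) for one parameter."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the current mean

    def update(self, x):
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    def variance(self):
        if self.count < 2:
            raise ValueError("need at least two samples")
        return self.m2 / (self.count - 1)


# Synthetic stand-in for draws of a tightly concentrated parameter in one window.
rng = random.Random(0)
samples = [rng.gauss(15.0, 5e-4) for _ in range(25)]

estimator = WelfordVariance()
for draw in samples:
    estimator.update(draw)

print(f"estimated variance after {estimator.count} draws: {estimator.variance():.3e}")
```

If the chain barely moves within a window because the step size is tiny, the draws fed to such an estimator are nearly identical and the resulting variance says little about the true posterior scale, which matches the guess above that the adaptation can stay stuck.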