These tests are flaky and slow and are impeding development, e.g.
FAILED tests/infer/mcmc/test_nuts.py::test_nuts_conjugate_gaussian[dim=5_chain-len=9_num_obs=1]
FAILED tests/infer/mcmc/test_hmc.py::test_hmc_conjugate_gaussian[dim=10_chain-len=4_num_obs=1]
Can we replace some of these with less flaky tests?
Is the issue that these tests are giving non-deterministic results (we had a few instances of this, but that has been resolved), or that they are extremely sensitive to small changes in the algorithm? I think at this point there is very little signal that we get from these conjugate gaussian tests, so I would be inclined to skip them altogether if I am unable to make them faster. They are still useful for profiling and checking if inference has improved due to a better chosen mass matrix etc., so we can keep them around for local testing.
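For context on what a less flaky check could look like, here is a minimal sketch. It assumes a single Normal prior over a mean with known observation noise, which is simpler than the models in tests/infer/mcmc. The function names are hypothetical and nothing here is taken from the Pyro test suite. The idea is to fix the seed and tie the tolerance to the Monte Carlo standard error instead of a hand-picked constant.

```python
import math
import random


def conjugate_gaussian_posterior(prior_mean, prior_std, obs, obs_std):
    """Closed-form posterior over the mean of a Normal with known obs_std."""
    prior_prec = 1.0 / prior_std ** 2
    obs_prec = len(obs) / obs_std ** 2
    post_prec = prior_prec + obs_prec
    post_mean = (prior_prec * prior_mean + sum(obs) / obs_std ** 2) / post_prec
    return post_mean, math.sqrt(1.0 / post_prec)


def sample_mean_within_tolerance(samples, target_mean, target_std, n_sigma=4.0):
    """Compare the sample mean to the target using a standard-error tolerance.

    Assumes roughly independent draws; for MCMC output the effective
    sample size should replace len(samples).
    """
    n = len(samples)
    estimate = sum(samples) / n
    std_err = target_std / math.sqrt(n)
    return abs(estimate - target_mean) <= n_sigma * std_err


rng = random.Random(0)  # fixed seed so the check is reproducible
post_mean, post_std = conjugate_gaussian_posterior(0.0, 1.0, [1.0], 1.0)
# Direct draws stand in for sampler output in this sketch.
draws = [rng.gauss(post_mean, post_std) for _ in range(2000)]
ok = sample_mean_within_tolerance(draws, post_mean, post_std)
print(post_mean, post_std, ok)
```

With a fixed seed the outcome is deterministic, so any failure points to a change in the algorithm and not to sampling noise.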
@neerajprad FYI, pyro NUTS is 20x slower than pymc3 NUTS. This is unacceptable :(
For the two models, the depth of the tree is quite similar, so I guess the time spent computing the energy and its grad is the main bottleneck. Let me investigate more.
I updated pytorch to the latest version and turned on jit_compile. Things are faster. Now, pyro is 4x slower than pymc.
You can get profiles file here: https://1drv.ms/f/s!AiPrcBQmpf6EgSW6A-7-QD77Y5Ax
Here is what I observed:
potential_jit is 4x faster than potential but grad is not improved much. Can we compile backward pass too?
@fehiepsi can you clarify what models you're using and what you mean by "compiling in pymc" and "without compiling, pyro is a little bit faster than pymc"? Are you also saying your model in Pyro is slower with the JIT than without in the latest version of PyTorch?
@eb8680 FYI, here are scripts to replicate results in pyro (1.0 branch) and pymc3:
# Source: Statistical Rethinking with Python and PyMC3
import numpy as np
import pandas as pd
import theano
# uncomment the following line to disable jit
# theano.config.mode = 'FAST_COMPILE'
import pymc3 as pm
np.random.seed(0)
rugged_df = (pd.read_csv('rugged.csv', sep=';')
             .assign(log_gdp=lambda df: np.log(df.rgdppc_2000))
             .dropna(subset=['log_gdp']))
with pm.Model() as model:
    a = pm.Normal('a', 0., 100.)
    bR = pm.Normal('bR', 0., 10.)
    bA = pm.Normal('bA', 0., 10.)
    bAR = pm.Normal('bAR', 0., 10.)
    mu = (a + bR * rugged_df.rugged + bA * rugged_df.cont_africa
          + bAR * rugged_df.rugged * rugged_df.cont_africa)
    sigma = pm.HalfCauchy('sigma', 2.)
    log_gdp = pm.Normal('log_gdp', mu, sigma, observed=rugged_df.log_gdp)
with model:
    trace_8_1 = pm.sample(1000, tune=1000, chains=1)
# Pyro script
import math
import pandas as pd
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import EmpiricalMarginal
from pyro.infer.mcmc import MCMC, NUTS
pyro.set_rng_seed(0)
rugged = pd.read_csv("rugged.csv", sep=";")
d = rugged
d["log_gdp"] = d["rgdppc_2000"].apply(math.log)
dd = d[d["rgdppc_2000"].notnull()]
dd_rugged = torch.tensor(dd["rugged"].values, dtype=torch.float)
dd_cont_africa = torch.tensor(dd["cont_africa"].values, dtype=torch.float)
dd_log_gdp = torch.tensor(dd["log_gdp"].values, dtype=torch.float)
def m8_1stan_model(rugged, cont_africa, log_gdp):
    a = pyro.sample("a", dist.Normal(0, 100))
    bR = pyro.sample("bR", dist.Normal(0, 10))
    bA = pyro.sample("bA", dist.Normal(0, 10))
    bAR = pyro.sample("bAR", dist.Normal(0, 10))
    mu = a + bR * rugged + bA * cont_africa + bAR * rugged * cont_africa
    sigma = pyro.sample("sigma", dist.HalfCauchy(2))
    with pyro.plate("plate"):
        pyro.sample("log_gdp", dist.Normal(mu, sigma), obs=log_gdp)
num_samples = 1000
warmup_steps = 1000
num_chains = 1
nuts_kernel = NUTS(m8_1stan_model, jit_compile=False)  # change to True to use jit
posterior = MCMC(nuts_kernel,
                 num_samples=num_samples,
                 warmup_steps=warmup_steps,
                 num_chains=num_chains)
posterior.run(rugged=dd_rugged, cont_africa=dd_cont_africa, log_gdp=dd_log_gdp)
Link to the data
PyMC3 uses Theano to compile the code under the hood. With theano.config.mode = 'FAST_COMPILE', Theano uses Python implementations instead (with only a few small optimizations). What I mean is that Pyro (with jit_compile=False) is a little bit faster than PyMC3 (with FAST_COMPILE).
When we compile the potential energy, the time spent in the forward pass drops from 80s to 15s, which is a big improvement (the time for the backward pass drops from 25s to 18s). But for pymc3, it drops from 134s to 4s (including computing the grad)!
Note that my profiles were run on a pretty old laptop without a GPU. For a more reliable comparison, I think it is better to run on both CPU and GPU on a modern machine.
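The forward/backward split discussed in this comment can be measured in isolation with a sketch like the one below. It is not the potential energy Pyro builds for NUTS: `potential` is a toy Gaussian negative log-likelihood made up for illustration, the data are random, and `time_forward_backward` is a hypothetical helper. It relies only on `torch.jit.trace` and `torch.autograd.grad`.

```python
import time

import torch


def potential(a, bR, rugged, log_gdp):
    # Toy stand-in for a potential energy: Gaussian negative log-likelihood
    # (up to a constant) with unit noise. Not what Pyro computes internally.
    mu = a + bR * rugged
    return 0.5 * ((log_gdp - mu) ** 2).sum()


torch.manual_seed(0)
rugged = torch.randn(170)   # random placeholder data, not rugged.csv
log_gdp = torch.randn(170)
a = torch.tensor(0.1, requires_grad=True)
bR = torch.tensor(-0.2, requires_grad=True)

# Trace once with example inputs; the traced function replays the recorded ops.
potential_jit = torch.jit.trace(potential, (a, bR, rugged, log_gdp))


def time_forward_backward(fn, n_iter=200):
    """Return total seconds spent in the forward and backward passes."""
    forward = backward = 0.0
    for _ in range(n_iter):
        t0 = time.perf_counter()
        energy = fn(a, bR, rugged, log_gdp)
        t1 = time.perf_counter()
        torch.autograd.grad(energy, (a, bR))
        t2 = time.perf_counter()
        forward += t1 - t0
        backward += t2 - t1
    return forward, backward


eager_times = time_forward_backward(potential)
jit_times = time_forward_backward(potential_jit)
print("eager  forward/backward (s):", eager_times)
print("traced forward/backward (s):", jit_times)
```

On a model this small the numbers mostly reflect per-call overhead, so they are only indicative; the profiles linked above remain the reference for the real model.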
Great work, and thanks for starting to dig into the performance aspects of NUTS! I will take a closer look at your benchmarks.
> Without compiling, pyro is a little bit faster than pymc.
I would have thought the same, since there are few opportunities for further optimization in our Python code and, for the models that I have seen, most of the time is spent in PyTorch. But it is great to know that we are actually a bit faster than pymc in this regard. :)
> With compiling, pyro is 4x slower than pymc (compiling in pymc gives it 8x improvement). This suggests that theano does better job in compiling (small models) than pytorch.
> potential_jit is 4x faster than potential but grad is not improved much.
This might improve with any improvements in trace.jit, but I doubt the performance gap will close by much with the JIT tracer alone. 4x isn't that bad though. :)
> Can we compile backward pass too?
torch.jit.script will be able to compile the backward pass too, but I am not sure if that's slated for the 1.0 release. I would expect models compiled with torch.jit.script to be faster, and we might be able to close the gap if we are able to JIT the integrator step itself (we may need to use a combination of jit.trace and jit.script though).
> we might be able to close the gap if we are able to JIT the integrator step itself (we may need to use a combination of jit.trace and jit.script though).
Yes, totally agree! (And some other methods such as kinetic_energy and is_turning too.) It seems that jit.script allows OrderedDict, so we might use OrderedDict instead of dict (I am not so sure; I ran into problems while playing around with jit.script).
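For readers unfamiliar with the pieces named here, the following is a minimal pure-Python sketch of an integrator step, a kinetic energy and a U-turn check. It assumes a diagonal inverse mass matrix and uses the U-turn criterion from the original NUTS paper with an identity metric. The signatures are made up for illustration and do not match Pyro's internal methods.

```python
def kinetic_energy(r, inv_mass):
    """0.5 * r^T M^{-1} r for a diagonal inverse mass matrix."""
    return 0.5 * sum(ri * ri * mi for ri, mi in zip(r, inv_mass))


def leapfrog_step(z, r, grad_potential, step_size, inv_mass):
    """One leapfrog step: half momentum, full position, half momentum."""
    g = grad_potential(z)
    r_half = [ri - 0.5 * step_size * gi for ri, gi in zip(r, g)]
    z_new = [zi + step_size * mi * ri for zi, mi, ri in zip(z, inv_mass, r_half)]
    g_new = grad_potential(z_new)
    r_new = [ri - 0.5 * step_size * gi for ri, gi in zip(r_half, g_new)]
    return z_new, r_new


def is_turning(z_left, r_left, z_right, r_right):
    """True when either end of the trajectory starts moving back toward the other."""
    diff = [zr - zl for zl, zr in zip(z_left, z_right)]
    dot_left = sum(d * r for d, r in zip(diff, r_left))
    dot_right = sum(d * r for d, r in zip(diff, r_right))
    return dot_left < 0 or dot_right < 0


def potential(z):
    # Standard normal target: U(z) = 0.5 * z^2
    return 0.5 * sum(zi * zi for zi in z)


def grad_potential(z):
    return list(z)


z0, r0, inv_mass = [1.0], [0.5], [1.0]
h0 = potential(z0) + kinetic_energy(r0, inv_mass)
z1, r1 = leapfrog_step(z0, r0, grad_potential, 0.01, inv_mass)
h1 = potential(z1) + kinetic_energy(r1, inv_mass)
print("energy change after one step:", h1 - h0)
```

Each step calls the gradient of the potential twice in this naive form, which is why the cost of the energy and its grad dominates; these small functions are the ones that would need to be expressed in a JIT-friendly way.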
@fritzo @neerajprad Are the slowness and flakiness fixed? If not, I'll take a look. We might just change random seed, or num steps, step size though.
@fehiepsi - These tests have become even slower on pytorch-1.0 (see https://github.com/pytorch/pytorch/issues/12190), so I am skipping them (I think they are more useful for benchmarking locally than for running on every CI build).
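One way to keep such tests available for local benchmarking while skipping them by default is sketched below. How the tests were actually marked in the Pyro repository is not shown in this thread; the environment variable `RUN_SLOW_MCMC_TESTS` and the test name are hypothetical.

```python
import os

import pytest

# Hypothetical opt-in flag; not an environment variable that Pyro defines.
RUN_SLOW = os.environ.get("RUN_SLOW_MCMC_TESTS") == "1"


@pytest.mark.skipif(not RUN_SLOW, reason="slow; run locally for benchmarking")
def test_conjugate_gaussian_benchmark():
    # Placeholder body: a real test would run the sampler and compare the
    # result against the closed-form posterior.
    assert True
```

Running with `RUN_SLOW_MCMC_TESTS=1 pytest` would then include the test, while a plain `pytest` run on CI reports it as skipped.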
Closing this because Neeraj has already marked these tests as skipped.