Pyro: MCMC does not work well with multi-chains in CPU due to Memory Error

Created on 1 Feb 2019  ·  15 Comments  ·  Source: pyro-ppl/pyro

The following script is copied exactly from the Bayesian regression tutorial (except num_samples=3000 and num_chains=2). Running it raises a MemoryError. I have tried to debug this by changing various system settings (ulimit, shared memory segment size), without success. I hit the issue on both of my systems, so I doubt it is specific to my setup. If so, this is a serious problem that may come from how we use torch.multiprocessing in Pyro.

import numpy as np
import pandas as pd
import torch

import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS

pyro.enable_validation(True)
pyro.set_rng_seed(1)
DATA_URL = "https://d2fefpcigoriu7.cloudfront.net/datasets/rugged_data.csv"
rugged_data = pd.read_csv(DATA_URL, encoding="ISO-8859-1")

#torch.multiprocessing.set_sharing_strategy("file_system")


def model(is_cont_africa, ruggedness, log_gdp):
    a = pyro.sample("a", dist.Normal(8., 1000.))
    b_a = pyro.sample("bA", dist.Normal(0., 1.))
    b_r = pyro.sample("bR", dist.Normal(0., 1.))
    b_ar = pyro.sample("bAR", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
    mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
    with pyro.iarange("data", len(ruggedness)):
        pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)


df = rugged_data[["cont_africa", "rugged", "rgdppc_2000"]]
df = df[np.isfinite(df.rgdppc_2000)]
df["rgdppc_2000"] = np.log(df["rgdppc_2000"])
train = torch.tensor(df.values, dtype=torch.float)
is_cont_africa, ruggedness, log_gdp = train[:, 0], train[:, 1], train[:, 2]

nuts_kernel = NUTS(model, adapt_step_size=True)
hmc_posterior = MCMC(nuts_kernel, num_samples=3000, warmup_steps=1000, num_chains=2).run(is_cont_africa, ruggedness, log_gdp)

Here is the traceback:

bug

All 15 comments

Could you also mention the pytorch version and OS? Does the memory keep increasing leading to the crash?

@neerajprad I tested on the following systems:

  • Python: 3.6, 3.7
  • OS: Ubuntu 16.04, 18.04
  • PyTorch: 0.4.1 (with an older version of Pyro), 1.0, nightly, pytorch-cpu, pytorch

Yes, memory keeps increasing, but I think that is normal behaviour for MCMC? Shared memory grows to about 400MB, and then the error happens. Do you see the same problem?

I don't see this on mac, but I can replicate it on my linux system. The shared memory shouldn't really increase because the queue is getting cleared in the main process.

> I don't see this on mac, but I can replicate it on my linux system. The shared memory shouldn't really increase because the queue is getting cleared in the main process.

Interesting! Good to know that it is a system issue. I have tried many ways to track down and resolve this error, and your information is the best lead I have so far. Thanks @neerajprad!

I think we are hitting a system-level limit on shared memory. If you change the _ParallelSampler to yield None, args[1].. instead of the trace, you will find that no error is thrown. I think what is happening is that even though we read from the queue, the objects remain in shared memory, and past a certain point we hit a limit that makes the OS kill the process.

One option would be to find out which limit we are hitting and raise it in sysctl.conf, but I don't think we should do that. I think this issue will go away once #1725 is resolved, since we will no longer keep the objects in shared memory but reduce them to some other representation as we go. That will keep our shared memory consumption fixed for the entire sampling run.

I have tried increasing various system limits, without success. Hopefully this will be resolved along with #1725.
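The thread never pins down which limit is being hit. As a starting point for that investigation, here is a minimal sketch that reports two candidates on a Unix system: the open file descriptor limit and the size of the shared memory filesystem. The function name shared_memory_limits and the default /dev/shm path are assumptions made for illustration; neither is confirmed as the cause here.

```python
import os
import resource
import shutil


def shared_memory_limits(shm_path="/dev/shm"):
    """Report limits that commonly matter when sharing tensors between processes.

    Hypothetical helper: `shm_path` defaults to the usual Linux tmpfs mount,
    which may not exist on other platforms.
    """
    # Per-process limit on open file descriptors (soft limit is the one enforced).
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    report = {"nofile_soft": soft, "nofile_hard": hard}

    # Capacity and free space of the shared memory filesystem, if present.
    if os.path.isdir(shm_path):
        usage = shutil.disk_usage(shm_path)
        report["shm_total_mb"] = usage.total // 2**20
        report["shm_free_mb"] = usage.free // 2**20
    return report


print(shared_memory_limits())
```

Running this before and during sampling would show whether free space in /dev/shm shrinks as samples accumulate.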

Btw, I found something interesting along the way: using the file_system strategy (via torch.multiprocessing.set_sharing_strategy) speeds up inference a lot (1.5x - 3x) over the file_descriptor strategy (the default one). :)
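For reference, switching the strategy can be sketched as follows. The wrapper name use_file_system_sharing is hypothetical; the torch.multiprocessing calls are the library's own. The call should happen before any worker processes are started, and file_system is only selected when the platform offers it.

```python
import torch.multiprocessing as mp


def use_file_system_sharing():
    """Select the file_system sharing strategy if available (hypothetical helper)."""
    # get_all_sharing_strategies() returns the set supported on this platform.
    if "file_system" in mp.get_all_sharing_strategies():
        mp.set_sharing_strategy("file_system")
    # Return whatever strategy is active after the attempt.
    return mp.get_sharing_strategy()


print(use_file_system_sharing())
```

Whether this gives a speedup is workload dependent; the numbers reported in this thread are for one model on one machine.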

> using file_system (in torch.multiprocessing.set_sharing_strategy) speeds up inference a lot (1.5x - 3x) over file_descriptor strategy (the default one). :)

Do you see this speed-up on this example too? I didn't see anything noticeable.

@neerajprad, for the above example (with jit_compile=True, num_samples=500 to avoid the MemoryError, and warmup_steps=20 to reduce its effect on timing), file_system vs file_descriptor:

  • num_chains=2: 2s vs 5s (averaged over all chains: 180it/s vs 90it/s)
  • num_chains=4: 3s vs 10s (160it/s vs 50it/s)
  • num_chains=6: 5s vs 15s (100it/s vs 33it/s)
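The speedups implied by these figures can be worked out directly. This is a small sketch that takes the first number in each pair as file_system and the second as file_descriptor (an assumption based on the ordering in the earlier comment); the dictionary names are illustrative.

```python
# (file_system, file_descriptor) wall-clock seconds, keyed by num_chains.
wall_clock = {2: (2, 5), 4: (3, 10), 6: (5, 15)}
# (file_system, file_descriptor) iterations per second, keyed by num_chains.
throughput = {2: (180, 90), 4: (160, 50), 6: (100, 33)}

# Speedup from timings: slower time divided by faster time.
time_speedup = {n: round(slow / fast, 1) for n, (fast, slow) in wall_clock.items()}
# Speedup from throughput: faster rate divided by slower rate.
rate_speedup = {n: round(fast / slow, 1) for n, (fast, slow) in throughput.items()}

print(time_speedup)
print(rate_speedup)
```

The two measures need not agree exactly, since wall-clock time also includes process startup and warmup.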

@fehiepsi - This seems to work fine now. Could you check this again on your system and close this issue if it is resolved?

Yup, it works nicely now. :)

@neerajprad I still get this problem with a higher number of samples and num_chains=4.

Does your proposed solution from the forum address this issue?

Yes, using the new API lets me use a higher num_samples, but the root problem remains. The solution in the forum fixes this issue.

@neerajprad I ran some benchmarks to see whether using the clone operator causes regressions.

import torch
torch.set_default_tensor_type(torch.cuda.FloatTensor)

params = {}
for i in range(100):
    params['x{}'.format(i)] = torch.randn(100)

%%time
for n in range(10000):
    cloned_params = {}
    for k, v in params.items():
        cloned_params[k] = v.clone()

It took 8.9s to generate 10000 samples (each sample has 100 sites and each site is a 100-D tensor). So the overhead is fairly large (though small compared to sampling) when many sites are involved.

If the checkpoint solution is not elegant (at least not to me, because it is unclear how to choose a checkpoint value that depends on the system/platform), we could have MCMC generate only flattened values. The overhead will be small:

params_flat = torch.cat(list(params.values()))

%%time
for n in range(10000):
    cloned_params = params_flat.clone()

This took 100ms. Once sampling is done, we just need to convert the stacked flat samples into stacked (dict type) samples, which should be cheap. I'll make a PR to address this and benchmark carefully.
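The conversion back to dict form can be sketched as below. This is a minimal illustration on CPU, not the code from the eventual PR; flatten_params and unflatten_samples are hypothetical names, and the sketch assumes every sample has the same sites with the same shapes, in the same key order.

```python
import torch


def flatten_params(params):
    """Concatenate every site's tensor into one 1-D tensor, in dict key order."""
    return torch.cat([v.reshape(-1) for v in params.values()])


def unflatten_samples(flat_samples, template):
    """Split stacked flat samples of shape (num_samples, total_size) into a dict
    mapping site name -> tensor of shape (num_samples, *site_shape)."""
    num_samples = flat_samples.shape[0]
    sizes = [v.numel() for v in template.values()]
    # One chunk per site, each of shape (num_samples, site_size).
    chunks = torch.split(flat_samples, sizes, dim=1)
    return {
        name: chunk.reshape((num_samples,) + tuple(v.shape))
        for (name, v), chunk in zip(template.items(), chunks)
    }


params = {"x{}".format(i): torch.randn(100) for i in range(100)}

# Collect a few flat samples (one clone per sample), then stack them.
flat_samples = torch.stack([flatten_params(params).clone() for _ in range(5)])

# Recover the dict representation once, after sampling is done.
samples = unflatten_samples(flat_samples, params)
print(flat_samples.shape, samples["x0"].shape)
```

The per-sample cost is a single clone of one flat tensor, and the dict is rebuilt only once at the end.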
