Pymc3: computation of logp_elemwise is slower than before

Created on 5 Aug 2017 · 6 Comments · Source: pymc-devs/pymc3

When working on the compare() function I noticed that computation of WAIC is now super-slow (compared to the last time I tried).

The problem is with the computation of logp_elemwise. I just ran the following minimal example:

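import numpy as np
import pymc3 as pm

# 100 draws from a standard normal serve as the observed data
x_obs = np.random.normal(0, 1, size=100)
with pm.Model() as model:
    mu = pm.Normal('mu', 0, 1)
    x = pm.Normal('x', mu=mu, sd=1, observed=x_obs)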

The following test takes around 80 ms on current master and less than 1 ms on the latest release (3.1 final):

%timeit x.logp_elemwise({'mu': 0.})

All 6 comments

The slowdown seems to be related to the NUTS PR https://github.com/pymc-devs/pymc3/pull/2345.

@aseyboldt Could you please have a look?

Ah, I think this is because I removed a couple of memoize calls in Factor: https://github.com/pymc-devs/pymc3/pull/2345/files#diff-1f25198dab6a35cd27fffe043b7e1b9dL598

The problem with those is that they are used in Model, too. Models aren't static, so a simple memoize will lead to wrong results if the model changes after one of the attributes is first used. This was one of the reasons for the design of model.logp_dlogp_function: it is called once and returns a function that is independent of the underlying model.
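To illustrate the stale-cache problem described above, here is a minimal sketch in plain Python. ToyModel, cached_total and make_total_function are hypothetical names invented for this illustration; none of this is PyMC3 code. It only shows why memoizing on a mutable object can return outdated results, and how a "call once, get back an independent function" design avoids that.

```python
from functools import lru_cache


class ToyModel:
    """Hypothetical stand-in for a mutable model; not PyMC3 code."""

    def __init__(self):
        self.terms = []

    def add_term(self, value):
        self.terms.append(value)

    @lru_cache(maxsize=None)
    def cached_total(self):
        # Memoized on `self` only, so later changes to `terms` are invisible.
        return sum(self.terms)

    def make_total_function(self):
        # Snapshot the current state once; the returned function no longer
        # depends on the model object.
        frozen = tuple(self.terms)
        return lambda scale=1.0: scale * sum(frozen)


toy = ToyModel()
toy.add_term(1.0)
stale = toy.cached_total()         # computed while the model has one term
toy.add_term(2.0)
still_stale = toy.cached_total()   # cache hit: the second term is ignored
total_fn = toy.make_total_function()
fresh = total_fn()                 # built after the change, sees both terms
```

The caller decides when to rebuild the function, so there is no hidden cache that can silently go out of date.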

You can work around that for now by doing this:

%timeit x.logp_elemwise({mu: 0.1})
logp = x.logp_elemwise
%timeit logp({mu: 0.1})
49.3 ms ± 1.84 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
22.7 µs ± 290 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

I believe all that caching was a source of slowdown before, too, since 1 ms isn't a reasonable time either. When storing the Theano function locally, you get 22 µs on my machine.
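The difference between the two timings can be sketched without PyMC3. The following assumes, as the comment above suggests, that each access to the logp_elemwise attribute builds a new function. ToyObserved and its compile_count counter are hypothetical stand-ins for illustration only, and the lambda is a placeholder for the compiled function, not the real log-density.

```python
class ToyObserved:
    """Hypothetical stand-in for an observed variable; not PyMC3 code."""

    def __init__(self):
        self.compile_count = 0

    @property
    def logp_elemwise(self):
        # Each attribute access "compiles" a new function, mimicking an
        # un-memoized property that builds a Theano function.
        self.compile_count += 1
        return lambda point: -0.5 * point["mu"] ** 2


points = [{"mu": 0.1 * i} for i in range(5)]

slow = ToyObserved()
# The property is accessed, and the function rebuilt, on every iteration.
slow_values = [slow.logp_elemwise(p) for p in points]

fast = ToyObserved()
logp = fast.logp_elemwise                 # build the function once
fast_values = [logp(p) for p in points]   # reuse it for every point
```

Both loops produce the same values; only the number of times the function is built differs (once per point versus once in total), which is the kind of cost that adds up when evaluating every point of a trace.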

That's great. I missed it.

I see, thanks!

@aloctavodia could you please change the way logp_elemwise is called (I guess it's in the function _log_post_trace in stats.py?) and see if it improves?

Thanks everyone for your help! Changes are implemented in #2479. WAIC computations are really fast now (at least 100x faster).

@aloctavodia Thanks for this! I noticed the same thing but couldn't track down the source.
