PyMC3: Broadcasting issue within LOO

Created on 14 Jun 2017 · 6 Comments · Source: pymc-devs/pymc3

I have written a linear regression with a Student's t likelihood in PyMC3 and am having trouble with model selection. I am fitting several separate linear regressions at once and have combined them into a single likelihood.

When I pass the trace and its model to LOO, I get the error "ValueError: operands could not be broadcast together with shapes (3,) (400,)". I have recreated the issue with the dummy data included below; the error persists even with generated data.

import numpy as np
import pymc3 as pm
obs = np.random.standard_t(10, size=(20, 100))
dep_rets = np.random.normal(0, 1, 20)
m, n = obs.shape
with pm.Model() as dummy_model:
    # Intercept prior (variance == sd**2)
    a = pm.Normal('alpha', mu=0, sd=100, shape=n) # uninformative prior
    # Slope prior
    b = pm.Normal('beta', mu=0, sd=100, shape=n) # uninformative prior

    # Degrees of freedom prior
    dof = pm.Gamma('dof', alpha=2, beta=0.1, shape=n)

    # error
    e = pm.HalfNormal('epsilon', sd=1, shape=n)

    # Data likelihood
    dummy_likelihood = pm.StudentT(
        'likelihood',
        nu=dof,
        mu=(a[None, :] + (dep_rets[:, None] * b[None, :])),
        sd=e[None, :],
        shape=n,
        observed=obs
    )
    dummy_trace = pm.sample(2000, njobs=4, tune=500)

pm.stats.loo(dummy_trace, model=dummy_model)

All 6 comments

loo seems to be broken for multidimensional observed variables:

with pm.Model() as model:
    sd = pm.HalfNormal('sd', sd=1)
    pm.Normal('a', sd=sd, shape=(2, 2), observed=np.random.randn(2, 2))
    trace = pm.sample()
    loo = pm.stats.loo(trace)

I believe this can be fixed by just reshaping the logp values in loo, but I have pretty much no idea what is happening in this function, so it would be great if someone else could have a look (git blame shows @twiecki, @fonnesbeck and a few others as authors).

diff --git a/pymc3/stats.py b/pymc3/stats.py
index c6ee4c1a..17834c51 100644
--- a/pymc3/stats.py
+++ b/pymc3/stats.py
@@ -211,6 +211,7 @@ def loo(trace, model=None, pointwise=False):
     model = modelcontext(model)

     log_py = log_post_trace(trace, model)
+    log_py = log_py.reshape((len(log_py), -1))

     # Importance ratios
     r = np.exp(-log_py)
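
To illustrate what the reshape in the diff does, here is a minimal NumPy-only sketch. The array below is random stand-in data, not the output of log_post_trace; it only assumes that the first axis indexes posterior samples and the remaining axes follow the shape of the observed variable, as in the (2, 2) example above.

```python
import numpy as np

# Hypothetical stand-in for pointwise log-likelihood values:
# 500 posterior samples of an observed variable with shape (2, 2).
rng = np.random.default_rng(0)
n_samples = 500
log_py = rng.normal(loc=-1.0, scale=0.3, size=(n_samples, 2, 2))

# Same reshape as in the diff: keep the sample axis and
# flatten all observation axes into a single one.
log_py_flat = log_py.reshape((len(log_py), -1))

print(log_py.shape)       # (500, 2, 2)
print(log_py_flat.shape)  # (500, 4)

# Each column of the flattened array is still the per-sample series
# of one observation; column 3 corresponds to element [1, 1].
assert np.array_equal(log_py_flat[:, 3], log_py[:, 1, 1])
```

With this layout, any code that expects a 2-D (samples, observations) array can work per column, regardless of how many dimensions the observed variable had.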

@mmargenot Oh, and thanks for reporting this :-)

CC @aloctavodia

I think @aseyboldt is right: reshaping the logp values should fix the problem. However, the fix should go inside the log_post_trace() function, to avoid reshaping separately inside both loo() and waic().
As a test, I have just compared LOO and WAIC for a couple of cases, and the results look OK.

@aseyboldt are you going to fix this, or should I?

I'll do it.
