Pymc3: Memory Error: OrderedLogistic

Created on 29 Jun 2019 · 19 Comments · Source: pymc-devs/pymc3

I am updating and re-running the Statistical Rethinking notebooks. I get a memory allocation error with model m_11 (code block 11.5). The problem seems to be related to #3383: reverting the Categorical distribution to its state before that PR fixes the issue. @lucianopaz, you probably have a better idea of what is going on here.


All 19 comments

I think the basic example given in the OrderedLogistic documentation (slightly modified to use ADVI) behaves differently in recent commits.
I ran the following code on the current HEAD and on older commits such as 9ef2947.

import numpy as np
import pymc3 as pm

n1_c = 300; n2_c = 300; n3_c = 300
cluster1 = np.random.randn(n1_c) + -1
cluster2 = np.random.randn(n2_c) + 0
cluster3 = np.random.randn(n3_c) + 2

x = np.concatenate((cluster1, cluster2, cluster3))
y = np.concatenate((1*np.ones(n1_c),
                    2*np.ones(n2_c),
                    3*np.ones(n3_c))) - 1

# Ordered logistic regression
with pm.Model() as model:
    cutpoints = pm.Normal("cutpoints", mu=[-1,1], sigma=10, shape=2,
                          transform=pm.distributions.transforms.ordered)
    y_ = pm.OrderedLogistic("y", cutpoints=cutpoints, eta=x, observed=y)
    inference = pm.fit(30000, method='advi')
    tr = inference.sample(1000)

In older versions, the ELBO seems to converge to a reasonable value of order 10^2, but in recent commits the progress bar shows a much larger value, of order 10^5.

I suspect this is caused by the logp method of Categorical.
In older versions, when p.ndim == 2 (as in the example above), the log-likelihood

a = tt.log(p[tt.arange(p.shape[0]), value_clip])

is a vector of length p.shape[0].
In recent versions, however, a, defined as

pattern = (p.ndim - 1,) + tuple(range(p.ndim - 1))
a = tt.log(p.dimshuffle(pattern)[value_clip])

is a p.shape[0] x p.shape[0] matrix (900**2 entries in the example above, which roughly matches the scale of the ELBO).
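The shape difference can be reproduced in plain NumPy. This is a sketch under two assumptions: that Theano's dimshuffle with a pure permutation pattern behaves like np.transpose, and that Theano advanced indexing follows NumPy semantics here. The uniform p is chosen only so the two sums are easy to compare.

```python
import numpy as np

n, k = 900, 3                         # 900 observations, 3 categories, as in the example
p = np.full((n, k), 1.0 / k)          # one probability row per observation
value = np.random.randint(0, k, size=n)

# Older logp: one (row, category) pair per observation -> shape (n,)
old = np.log(p[np.arange(n), value])

# Newer logp: move the category axis to the front, then index it -> shape (n, n)
# Assumption: np.transpose stands in for Theano's dimshuffle with this pattern.
pattern = (p.ndim - 1,) + tuple(range(p.ndim - 1))   # (1, 0) for a 2-D p
new = np.log(p.transpose(pattern)[value])

print(old.shape, new.shape)           # (900,) (900, 900)
print(new.sum() / old.sum())          # approximately 900: the summed logp is inflated n-fold
```

Because every row of the (n, n) result is a full copy of the per-observation terms, summing it multiplies the log-likelihood by n, which fits the jump from about 10^2 to about 10^5 in the ELBO.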

Thanks @tohtsky, I think your diagnosis is right.

@aloctavodia, sorry for the delay in taking a look at this. I'm still mostly offline until next week.

I looked a bit more into the commit you referenced, and I think there may be a bug in Categorical.logp, specifically in the line that builds a with dimshuffle. If p is multidimensional, the dimshuffle moves the last axis to the front to simplify advanced indexing into it with the supplied values. The resulting a then has an odd shape like (value.shape, p.shape[:-1]) instead of having value.shape at the end. This could lead to strange broadcasting with the observed values and produce an intermediate array of shape (value.shape, value.shape), which could cause the memory error you encountered. Maybe that line could be replaced with this:

a = tt.log(p[..., value_clip])

I'll be able to test this in a bit less than two weeks, but if you want to try it out first, you're welcome to do so.

Oops, I skipped over @tohtsky's answer, which points to exactly the same line of code that I thought was causing the bug in my previous answer. Maybe we could try using an ellipsis in the indexing instead of a dimshuffle.

I tried replacing the line you suggested and I still see the same problem :-(
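A NumPy check on a 3-D p suggests why the ellipsis version fails too: an integer array applied to the last axis is not paired with the leading axes, so the result combines p.shape[:-1] and value.shape instead of keeping one entry per batch element. The dimshuffle version has the same total size, just with value.shape first. This assumes Theano's indexing follows NumPy semantics for both forms, and uses np.transpose as a stand-in for dimshuffle.

```python
import numpy as np

p = np.full((2, 3, 4), 0.25)              # batch shape (2, 3), k = 4 categories
value = np.array([[0, 1, 2], [3, 0, 1]])  # one category per batch element, shape (2, 3)

# dimshuffle-style: move the category axis first, then index it
pattern = (p.ndim - 1,) + tuple(range(p.ndim - 1))   # (2, 0, 1)
a_dimshuffle = p.transpose(pattern)[value]           # value.shape + p.shape[:-1]

# ellipsis-style: index the last axis directly
a_ellipsis = p[..., value]                           # p.shape[:-1] + value.shape

# Desired result would have shape (2, 3): one entry per batch element
print(a_dimshuffle.shape, a_ellipsis.shape)          # (2, 3, 2, 3) (2, 3, 2, 3)
```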

Would it work to simply flatten the multidimensional tensor into a 2-D array and then reshape the resulting logp vector back into the original tensor shape?

        if p.ndim > 1:
            original_shape = p.shape[:-1]
            p_flatten = p.reshape((-1, p.shape[-1]))
            a = tt.log(
                p_flatten[
                    (tt.arange(p_flatten.shape[0]), value_clip.ravel())
                ]).reshape(
                    original_shape
            )
        else:
            a = tt.log(p[value_clip])
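The same flatten-then-reshape idea can be sketched in NumPy. It assumes value has shape p.shape[:-1] (the Theano snippet relies on that too, through value_clip.ravel()), and categorical_logp_flatten is a hypothetical name used only for illustration. The result can be cross-checked against np.take_along_axis, which picks one entry per row along the last axis.

```python
import numpy as np

def categorical_logp_flatten(p, value):
    """Hypothetical helper: per-element log p[..., value] via flatten and reshape.

    Assumes p has shape (..., k) and value has shape p.shape[:-1].
    """
    if p.ndim > 1:
        batch_shape = p.shape[:-1]
        p_flat = p.reshape(-1, p.shape[-1])          # (n, k) with n = prod(batch_shape)
        rows = np.arange(p_flat.shape[0])
        picked = p_flat[rows, value.ravel()]         # one entry per row: shape (n,)
        return np.log(picked).reshape(batch_shape)   # restore the batch shape
    return np.log(p[value])


rng = np.random.default_rng(0)
p = rng.dirichlet(np.ones(3), size=(4, 5))    # shape (4, 5, 3)
value = rng.integers(0, 3, size=(4, 5))
logp = categorical_logp_flatten(p, value)     # shape (4, 5)
```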

I'll have time to look into this issue this Friday.

Hi,
Any progress on this issue?

I managed to write a small test involving logp that highlights the error:

import numpy as np
from scipy.special import logit
import pymc3 as pm


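loge = np.log10(np.exp(1))  # multiply a natural-log value by this to get log10
size = 100
p = np.ones(10) / 10  # 10 equally likely classes
cutpoints = logit(np.linspace(0, 1, 11)[1:-1])  # logit(0.1), ..., logit(0.9)
obs = np.random.randint(0, 1, size=size)  # upper bound is exclusive: every observation is class 0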
with pm.Model():
    ol = pm.OrderedLogistic("ol", eta=0, cutpoints=cutpoints, observed=obs)
    c = pm.Categorical("c", p=p, observed=obs)

print(ol.logp({"ol": 1}) * loge)
print(c.logp({"c": 1}) * loge)

The OrderedLogistic variable ol, with the given cutpoints, should be equivalent to the Categorical RV c. Since all classes are equally likely, the log10 of the logp that each RV returns for a given class should be approximately -size == -100. This holds for the Categorical, but the OrderedLogistic returns -size**2 == -10000. The problem seems to be an unnecessary shape padding done in OrderedLogistic.__init__.

Furthermore,

>>> ol.distribution.p.ndim
2
>>> c.distribution.p.ndim
1

However, there is also a separate shape issue when an array of cutpoints is given, related to how it broadcasts with the passed observed values. I'll try to finish a fix tomorrow.
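The -100 versus -10000 figures can be checked by hand. Assuming the usual cumulative-logit form P(y <= k) = sigmoid(c_k - eta), with eta = 0 and cutpoints at logit(0.1), ..., logit(0.9), every class gets probability 0.1, so 100 observations give a log10-likelihood of 100 * log10(0.1) = -100. If the per-observation terms are instead spread over a (size, size) array, the sum becomes -10000. The np.broadcast_to call below only mimics that shape blow-up; it is not PyMC3's actual code path.

```python
import numpy as np
from scipy.special import expit, logit

# Cutpoints at logit(0.1), ..., logit(0.9), as in the test above
cutpoints = logit(np.linspace(0, 1, 11)[1:-1])
eta = 0.0

# Assumed cumulative form: P(y <= k) = expit(c_k - eta), padded with 0 and 1
cdf = np.concatenate(([0.0], expit(cutpoints - eta), [1.0]))
p = np.diff(cdf)                      # 10 class probabilities, each close to 0.1

size = 100
obs = np.zeros(size, dtype=int)       # every observation in class 0

# Correct joint log10-likelihood: one term per observation
correct = np.log10(p[obs]).sum()

# Mimic of the bug's shape: a (size, size) array of terms gets summed
buggy = np.log10(np.broadcast_to(p[obs], (size, size))).sum()
```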

In the end, it wasn't a broadcasting problem. It was an indexing problem when p was multidimensional.

tt.choose ended up making more of a mess, so we'll have to come up with a different fix for this problem.

@seberg was kind enough to point me to take_along_axis which is exactly what we need to replace choose with to fix this issue and not fall into the problems raised in #3564 and #3567. Luckily, take_along_axis is implemented using advanced indexing and we should be able to copy the relevant part to get a theano based implementation to use in Categorical.logp.
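A NumPy sketch of that advanced-indexing approach might look like the following. take_along_last_axis is a hypothetical name; it assumes value has shape p.shape[:-1], and a Theano version would have to build the same index tuple from symbolic aranges, which is not shown here.

```python
import numpy as np

def take_along_last_axis(p, value):
    """Hypothetical helper: return one entry from the last axis of p per batch element.

    Assumes value.shape == p.shape[:-1]. One broadcastable arange is built per
    leading axis, the same advanced-indexing trick np.take_along_axis uses.
    """
    batch_shape = p.shape[:-1]
    index = []
    for axis, n in enumerate(batch_shape):
        shape = [1] * len(batch_shape)
        shape[axis] = n                          # arange varies only along this axis
        index.append(np.arange(n).reshape(shape))
    index.append(value)                          # category index for each batch element
    return p[tuple(index)]


rng = np.random.default_rng(1)
p = rng.dirichlet(np.ones(4), size=900)          # shape (900, 4): one row per observation
value = rng.integers(0, 4, size=900)
logp = np.log(take_along_last_axis(p, value))    # shape (900,), not (900, 900)
```

Unlike the dimshuffle and ellipsis forms, the broadcast aranges pair each category index with its own row, so the output keeps the batch shape of p.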

Hi @lucianopaz,
Unfortunately, that PR is still very slow for a multidimensional p:

import numpy as np
import pymc3 as pm

data = np.random.randint(0, 3, size=(1000, 1))

with pm.Model() as model:
    tp1 = pm.Dirichlet('tp1', a=np.array([0.25]*4), shape=(4,))  # 4 free RVs
    obs = pm.Categorical('obs', p=tp1, observed=data)
    trace = pm.sample()  # super fast!

data_indexer = np.random.randint(0, 2, size=(1000,))

with pm.Model() as model:
    tp1 = pm.Dirichlet('tp1', a=np.array([0.25]*4), shape=(2, 4))  # 8 free RVs
    obs = pm.Categorical('obs', p=tp1[data_indexer, :], observed=data)
    trace = pm.sample()  # takes ages!

Does the second model sample OK for you (just in case I've done something silly with my install)?

@bdyetton, it samples super slow for me too. I'll try to find out why this is happening.

Great, thanks!!!


Regards,
Ben Yetton

@lucianopaz I'm not sure if this is helpful at all, but the second model above will not begin sampling with the slice step method, so this is not an issue affecting NUTS only.

Hi @lucianopaz, any progress? Anything I can do to help?

@bdyetton, sorry, I have to deal with other stuff from work first. Once I finish, I'll be able to look into this more deeply.

@lucianopaz, Thanks!!!
