Pyro: JIT trace does not work correctly with pyro.plate

Created on 28 Dec 2018 · 4 comments · Source: pyro-ppl/pyro

I am not sure whether this is a bug or whether I am using pyro.plate incorrectly.

The following code shows the issue. I ran into it while trying to train an HMM. Everything works with the non-JIT version, but if I call torch.autograd.grad(loss_jit(x), (x,)) twice, it throws RuntimeError: "Can not go backward 2 times for the graph".

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
import torch

def model():
    with pyro.plate("plate", 2):
        pyro.sample("x", dist.Normal(0, 1))
    #pyro.sample("x", dist.Normal(0, 1).expand((2,)).to_event())  # this works!

trace = poutine.trace(model).get_trace()

def loss(x):
    trace.nodes["x"]["value"] = x
    return poutine.trace(poutine.replay(model, trace=trace)).get_trace().log_prob_sum()

x = torch.tensor([0., 1.], requires_grad=True)
loss_jit = torch.jit.trace(loss, x, check_trace=False)
print(loss_jit(x))  # -2.3379 (correct)
print(loss_jit(x))  # -4.6758 (wrong: should again be -2.3379)
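As a quick sanity check in plain Python (the helper normal_log_prob is written here for illustration and is not part of Pyro), the standard normal log density evaluated at x = [0., 1.] gives the first printed value, and the second printed value is exactly twice it:

```python
import math


def normal_log_prob(x, loc=0.0, scale=1.0):
    # log N(x; loc, scale) = -((x - loc) / scale)^2 / 2 - log(scale) - log(2*pi) / 2
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


# Log-probability of x = [0., 1.] under two independent Normal(0, 1) draws.
expected = normal_log_prob(0.0) + normal_log_prob(1.0)

print(round(expected, 4), round(2 * expected, 4))  # -2.3379 -4.6758
```

So the second call is not a different but plausible value: it is the correct result added onto itself.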

This doesn't seem related to plate per se, but rather to the fact that plate generates a sample site whose log_prob method returns a constant. The following loss reproduces the same behavior:

def loss(x):
    res = 0.
    res += torch.tensor(0.)
    res += torch.distributions.Normal(0, 1).expand((2,)).log_prob(x).sum()
    return res

# this will fail silently, like the snippet above
loss_jit = torch.jit.trace(loss, x, check_trace=False)

# this fails with a compilation error, whereas the Pyro snippet above does not fail even with check_trace=True
loss_jit = torch.jit.trace(loss, x, check_trace=True)

# this will work as expected
loss_jit = torch.jit.trace(loss, x, check_trace=False, optimize=False)

# this will also work as expected
loss_jit = torch.jit.trace(loss, x, check_trace=True, optimize=True, _force_outplace=True)

The root cause seems to be a bug in the JIT's optimizer, which is enabled by default: if you disable it by passing optimize=False to torch.jit.trace, the error goes away.

There's an in-place add at this line in Trace.log_prob_sum: https://github.com/uber/pyro/blob/dev/pyro/poutine/trace_struct.py#L149. If you replace it with result = result + log_p, both the in-place instruction in the generated code and the errors go away, even with the optimizer turned on. There is also an undocumented argument, _force_outplace (False by default); passing _force_outplace=True prevents the JIT from generating in-place ops, even without changing that line in log_prob_sum.
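To make the suspected mechanism concrete, here is a pure-Python toy, not the PyTorch JIT; the class Value and the function replay_program are hypothetical stand-ins. It assumes, consistent with the doubling in the report, that the traced program keeps the accumulator's starting value as a constant shared by every call. Like torch.Tensor, Value implements += as an in-place update, so the in-place variant corrupts the shared constant on each call, while result = result + log_p leaves it intact:

```python
class Value:
    """Hypothetical stand-in for a tensor: like torch.Tensor, += is in place."""

    def __init__(self, x):
        self.x = x

    def __add__(self, other):
        # Out-of-place add: builds and returns a new object.
        return Value(self.x + other.x)

    def __iadd__(self, other):
        # In-place add: mutates this object and returns it.
        self.x += other.x
        return self


def replay_program(use_inplace):
    """Toy replay of a traced loss (assumption: the accumulator's starting
    value was recorded once during tracing and is reused by every call)."""
    baked_constant = Value(0.0)

    def loss(log_p):
        result = baked_constant
        if use_inplace:
            result += log_p  # like `result += log_p`: mutates the shared constant
        else:
            result = result + log_p  # like the fix: the constant stays at 0.0
        return result.x

    return loss


log_p = Value(-2.3379)
buggy = replay_program(use_inplace=True)
fixed = replay_program(use_inplace=False)
print(buggy(log_p), buggy(log_p))  # -2.3379 -4.6758
print(fixed(log_p), fixed(log_p))  # -2.3379 -2.3379
```

The out-of-place form costs one extra allocation per site, which is why it is a cheap fix for a sum over a handful of sample sites.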


Excellent! Thanks a lot, @eb8680, for the clear explanation! Things work smoothly now. :)

Nice sleuthing, @eb8680! We should fix that line in Trace.log_prob_sum() to use result = result + _ rather than +=.

Closed via #1687.
