I am not sure whether this is a bug or whether I am using "pyro.plate" incorrectly.
The following code shows the issue, which I ran into while trying to train an HMM model. The non-JIT version works fine. With the JIT version, calling torch.autograd.grad(loss_jit(x), (x,)) twice throws a RuntimeError: "Can not go backward 2 times for the graph".

    import pyro
    import pyro.distributions as dist
    import pyro.poutine as poutine
    import torch

    def model():
        with pyro.plate("plate", 2):
            pyro.sample("x", dist.Normal(0, 1))
            # pyro.sample("x", dist.Normal(0, 1).expand((2,)).to_event())  # this works!

    trace = poutine.trace(model).get_trace()

    def loss(x):
        trace.nodes["x"]["value"] = x
        return poutine.trace(poutine.replay(model, trace=trace)).get_trace().log_prob_sum()

    x = torch.tensor([0., 1.], requires_grad=True)
    loss_jit = torch.jit.trace(loss, x, check_trace=False)
    print(loss_jit(x))  # -2.3379
    print(loss_jit(x))  # -4.6758

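As a sanity check on the two printed values, the expected log-probability can be worked out by hand. The sketch below uses only the standard library; `normal_log_prob` is a hypothetical helper written for this illustration, not a Pyro or PyTorch function. It suggests that the first printed value is the correct one and that the second is exactly twice that.

```python
import math

def normal_log_prob(value, loc=0.0, scale=1.0):
    # Log-density of a univariate Normal(loc, scale) evaluated at `value`.
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)

# Sum over the two plate elements x = [0., 1.] under Normal(0, 1).
expected = sum(normal_log_prob(v) for v in [0.0, 1.0])

print(round(expected, 4))      # -2.3379, matches the first call
print(round(2 * expected, 4))  # -4.6758, matches the second call
```

The doubling is consistent with a value being accumulated across calls rather than recomputed from scratch.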
This doesn't seem related to plate per se, but rather the fact that plate generates a sample site whose log_prob method returns a constant. This loss generates identical behavior:

    def loss(x):
        res = 0.
        res += torch.tensor(0.)
        res += torch.distributions.Normal(0, 1).expand((2,)).log_prob(x).sum()
        return res

    # this fails silently, like the snippet above
    loss_jit = torch.jit.trace(loss, x, check_trace=False)
    # this fails with a compilation error; the snippet above fails silently even with check_trace=True
    loss_jit = torch.jit.trace(loss, x, check_trace=True)
    # this works as expected
    loss_jit = torch.jit.trace(loss, x, check_trace=False, optimize=False)
    # this also works as expected
    loss_jit = torch.jit.trace(loss, x, check_trace=True, optimize=True, _force_outplace=True)

The root cause seems to be a bug in the JIT's optimizer, which is turned on by default - if you turn the optimizer off by passing optimize=False to torch.jit.trace, the error goes away.
There's an in-place add at this line in Trace.log_prob_sum: https://github.com/uber/pyro/blob/dev/pyro/poutine/trace_struct.py#L149. If you replace that with result = result + log_p, the in-place instruction in the generated code and the errors both go away even with the optimizer turned on. There also seems to be an undocumented switch _force_outplace=True, disabled by default, that prevents the JIT from generating in-place ops even without changing the line in log_prob_sum.
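To make the `+=` versus `result = result + log_p` distinction concrete, here is a minimal eager-mode sketch. It deliberately does not call torch.jit.trace, since the JIT behavior described above depends on the PyTorch version; it only shows that the two accumulation styles are numerically equivalent outside the JIT. The helper names `site_log_probs`, `log_prob_sum_in_place`, and `log_prob_sum_out_of_place` are hypothetical stand-ins for this illustration, not Pyro's actual implementation.

```python
import torch
from torch.distributions import Normal

def site_log_probs(x):
    # Hypothetical stand-in for per-site terms: a constant term
    # (like the plate site) followed by a data-dependent term.
    return [torch.tensor(0.), Normal(0., 1.).log_prob(x).sum()]

def log_prob_sum_in_place(x):
    result = 0.
    for log_p in site_log_probs(x):
        # Once `result` is a tensor, `+=` mutates it in place.
        result += log_p
    return result

def log_prob_sum_out_of_place(x):
    result = 0.
    for log_p in site_log_probs(x):
        # Allocates a new tensor on every step; nothing is mutated.
        result = result + log_p
    return result

x = torch.tensor([0., 1.], requires_grad=True)
in_place = log_prob_sum_in_place(x)
out_of_place = log_prob_sum_out_of_place(x)

# Each result has its own graph, so each can be differentiated once.
grad_in, = torch.autograd.grad(in_place, (x,))
grad_out, = torch.autograd.grad(out_of_place, (x,))

print(in_place.item(), out_of_place.item())
print(grad_in, grad_out)
```

In eager mode both versions agree on the value and the gradient, so switching to the out-of-place form should not change results; it only avoids emitting the in-place instruction that the traced code trips over.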
Excellent! Thanks a lot @eb8680 for a clear explanation! Things work smoothly now. :)
Nice sleuthing, @eb8680! We should fix that line in Trace.log_prob_sum() to use result = result + _ rather than +=.
Closed via #1687.