Many of our JIT tests for HMC / NUTS, e.g. test_gamma_normal, test_dirichlet_categorical, etc., are not terminating on the pytorch-1.0 branch. The reason is that the step size becomes extremely small during the adaptation phase, so sampling crawls to a halt. Note that the tests run fine without JIT. This probably points to some incompatibility between the adaptation code and the assumptions that the JIT is making. This is likely a regression due to some of our internal changes / refactoring, since these tests were mostly working earlier. cc @fehiepsi
@neerajprad I would like to resolve this bug to learn a bit about jit (I have no idea how it works at the moment ^___^!). Could you let me know how to do these tests with jit?
@fehiepsi - Great! You will need to check out the pytorch-1.0 branch to get all the JIT-related changes (see the corresponding PR #1431). You will also need to be on PyTorch master (or download the nightly build using conda install torch_nightly -c pytorch, though it might have a few perf issues that are being worked on). You can run all JIT tests using make test-jit. Since that takes a bit of time, for debugging I would suggest running the JIT tests directly from test_hmc.py and test_nuts.py.
All that does is JIT-compile the potential energy computation the first time it runs, and use the compiled version subsequently. You can turn off the JIT warnings by setting ignore_jit_warnings=True, which should already be the case for some of these tests. Let me know if you face any issues.
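The compile-on-first-call behaviour described here can be sketched in plain Python. This is only an illustration of the pattern, not Pyro's actual implementation: the class name LazyCompiled, the compile_fn hook, and fake_compile are all hypothetical. In real use the hook could wrap something like torch.jit.trace(fn, example_args).

```python
class LazyCompiled:
    """Compile `fn` on the first call, then reuse the compiled version."""

    def __init__(self, fn, compile_fn):
        self._fn = fn
        # Hypothetical hook: takes (fn, example_args) and returns a callable.
        self._compile_fn = compile_fn
        self._compiled = None

    def __call__(self, *args):
        if self._compiled is None:
            # First call: build the compiled function from the example inputs.
            self._compiled = self._compile_fn(self._fn, args)
        return self._compiled(*args)


compile_calls = []


def fake_compile(fn, example_args):
    # Stand-in for a real tracer; it only records that compilation happened.
    compile_calls.append(example_args)
    return fn


# Toy potential energy, chosen only for illustration.
potential_energy = LazyCompiled(lambda z: 0.5 * z * z, fake_compile)
first = potential_energy(2.0)   # triggers the one-time compilation
second = potential_energy(3.0)  # reuses the compiled callable
```

The point of the pattern is that compilation cost is paid once, so any error baked in at compile time is also reused on every later call.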
@neerajprad Thanks for your explanation! Now I understand the purpose of JIT.
I tried to debug and saw that _potential_energy_jit gives different derivatives compared to the non-JIT version. I checked the input of the _potential_grad function in integrator.py. Given the same z, JIT and non-JIT give the same potential_energy but different grads.
Thanks for debugging, @fehiepsi. It seems that the autograd graph constructed by JIT is not correct for some reason. Is it the case for all tests or only some tests? If the grads are different only for certain tests, I suspect the culprit might be some distributions' log_prob methods. This might get tricky to debug, I will also take a look at it later this week.
You are right that JIT grad is just incorrect for gamma and dirichlet models. Other models look fine. This seems like a PyTorch bug.
@neerajprad I posted on Slack about a bug I noticed. You might check it there. :)
@neerajprad I think I have caught the bug now. Please take a look at the notebook https://gist.github.com/fehiepsi/e2cc69bfaa9b00033834756b3092970f
It seems that we have a bug when using gamma + trace + jit.
cc @fritzo
Thanks for digging in, @fehiepsi. Taking a look!
Thanks, @fehiepsi. Moving the discussion from the gist so that I can get notified.
The following snippet from your example has the same issue, but I have not been able to figure out a smaller example that is independent of poutine.trace. There might be something inside of our poutine internals that doesn't behave well with JIT. It is still a bit concerning that the JIT does not raise any warnings (maybe it is expected because check_trace only checks for consistency in the output of the JIT and python functions, not the gradients).
import torch
import torch.autograd as autograd
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
data = dist.Normal(3, 0.5).sample(torch.Size([1000]))
def model():
    z = pyro.sample("z", dist.Gamma(1, 1))
    pyro.sample("obs", dist.Normal(3, z), obs=data)
def fn(z):
    # Condition the model on z and sum the log-probabilities of both sites.
    trace = poutine.trace(poutine.condition(model, {"z": z})).get_trace()
    return trace.nodes["z"]["fn"].log_prob(z).sum() + trace.nodes["obs"]["fn"].log_prob(data).sum()
z = torch.tensor(1., requires_grad=True)
fn_jit = torch.jit.trace(fn, (z,))
# The outputs agree, but the gradients do not.
print(fn(z))
print(fn_jit(z))
print(autograd.grad(fn(z), (z,)))
print(autograd.grad(fn_jit(z), (z,)))  # Same issue
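To illustrate why an output-only consistency check can pass while gradients differ, here is a PyTorch-free toy. It is a sketch under stated assumptions: the function miscompiled is a hypothetical faulty compilation invented for illustration, and it is not a claim about what the JIT actually did internally in this issue.

```python
import math

N = 1000  # number of observations, matching the snippet above


def eager(z):
    # Reference function: every term keeps its dependence on z.
    return -z + N * math.log(z)


def miscompiled(z, z_traced=1.0):
    # Hypothetical faulty compilation: the log term is frozen at the value
    # seen at trace time, so it no longer depends on z.
    return -z + N * math.log(z_traced)


def grad(f, x, eps=1e-6):
    # Central finite difference, used here in place of autograd.
    return (f(x + eps) - f(x - eps)) / (2 * eps)


z = 1.0
outputs_agree = math.isclose(eager(z), miscompiled(z))
grads_agree = math.isclose(grad(eager, z), grad(miscompiled, z), rel_tol=1e-4)
print(outputs_agree, grads_agree)
```

At the trace point both functions return the same value, so a check that compares only outputs would not flag anything; comparing gradients as well is what exposes the difference.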
Are you sure it's OK to close fn over data in that example? trace and condition are not doing any PyTorch operations and should be completely invisible to the JIT.
Some observations:
- If we return log_prob at z, then the grad is right. If we return log_prob at obs, things are still right. But if we return their sum (or any linear combination of them), then the grad is wrong.
- If we return log_prob_at_z + 0 * log_prob_at_obs, then the grad is increased by 1000 times. This is the number of data points in obs.
- [...] .expand. But it seems the problem has been solved in PyTorch master (no issue with the JIT of fn2 in my gist).
Are you sure it's OK to close fn over data in that example?
We have the same issue with data inside fn.
trace and condition are not doing any PyTorch operations and should be completely invisible to the JIT.
You are right. I am pretty sure we should be able to reconstruct a failing example without the poutine code, and by just using dicts, but I have not been able to recreate a more minimal example yet.
This just happens with some distributions such as Gamma, Dirichlet. Other distributions seem fine.
That's interesting - maybe there is some bug in the autograd graph wrt some operation inside of our poutine code for these distributions, but not others.
@neerajprad @eb8680 This is not a pyro problem. I can replicate the bug in PyTorch
import torch
import torch.autograd as autograd
import torch.distributions as dist
def fn(z):
    a = dist.Gamma(1, 1)
    return a.log_prob(z).sum() + (z.log() - data).sum()
data = torch.zeros(1000)
z = torch.tensor(1., requires_grad=True)
fn_jit = torch.jit.trace(fn, (z,))
print(fn(z))  # -1
print(fn_jit(z))  # -1
print(autograd.grad(fn(z), (z,)))  # return 999
print(autograd.grad(fn_jit(z), (z,)))  # return 0
I have raised the bug in pytorch slack.
The bug is reported at: https://github.com/pytorch/pytorch/issues/13669
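As a sanity check on which of the two reported gradients is correct, the expected value can be worked out by hand. With concentration 1 and rate 1, the Gamma log-density reduces to -z, and the second term is 1000 * log(z) because data is all zeros, so the derivative is -1 + 1000 / z, which is 999 at z = 1. The sketch below confirms this with the standard library only; the helper names are made up for illustration.

```python
import math

N = 1000  # length of `data`, which is all zeros in the snippet above


def gamma_log_prob(z, concentration=1.0, rate=1.0):
    # Log-density of a Gamma distribution in the (concentration, rate) form.
    return (concentration * math.log(rate)
            + (concentration - 1.0) * math.log(z)
            - rate * z
            - math.lgamma(concentration))


def fn(z):
    # Same quantity as the PyTorch function: Gamma log_prob plus sum(log(z) - 0).
    return gamma_log_prob(z) + N * (math.log(z) - 0.0)


def analytic_grad(z):
    # d/dz [-z + N * log(z)] = -1 + N / z
    return -1.0 + N / z


eps = 1e-6
numeric = (fn(1.0 + eps) - fn(1.0 - eps)) / (2 * eps)
analytic = analytic_grad(1.0)
print(fn(1.0), analytic, numeric)
```

Both the closed form and the finite-difference estimate agree with the non-JIT result of 999, which supports reading the JIT result of 0 as the incorrect one.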
Thanks @fehiepsi. Let us keep the issue open so that we can uncomment the tests once the upstream issue is resolved, and verify the fix.