Pyro: How to find cause of nan error in ELBO function

Created on 13 Nov 2017  ·  12 Comments  ·  Source: pyro-ppl/pyro

I am trying to modify the vae.py example, and the ELBO loss function is returning nan before I even do any training steps. My model and guide functions give sensible output, so I have no way to guess where the error is. Thanks for your help.

I wouldn't typically open a github issue for this kind of question, but I am not sure what the best place to get support for pyro is. Thanks!

usability warnings & errors

All 12 comments

I'm not sure the best place to debug NaNs, but here's what I'd try:

  1. Insert checks like assert not np.isnan(my_tensor.sum()) into Pyro code.
  2. Try using advanced Pyro features like

    tr = poutine.trace(guide).get_trace()
    tr.log_pdf()
    print('\n'.join('{}: {}'.format(name, site['log_pdf']) for name, site in tr.nodes.items()))

If you do find a good debugging trick, we'd be happy to add it to Pyro for error checking.

(And yes, this is the right place for questions. We're in the process of setting up a gitter and a discuss.)
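
The per-site printout suggested above can be turned into an automatic check. The following is a minimal sketch in plain Python, not Pyro code: `first_bad_site` is a hypothetical helper, and the dictionaries of site names and log-density values are made-up stand-ins for what a trace would report.

```python
import math

def first_bad_site(site_log_pdfs):
    """Return (name, value) for the first site whose log-density is NaN or infinite, else None."""
    for name, value in site_log_pdfs.items():
        if not math.isfinite(value):
            return name, value
    return None

# Hypothetical per-site values, standing in for the numbers a trace would report.
guide_sites = {"latent": -3.2}
model_sites = {"latent": -4.1, "obs": float("nan")}

for label, sites in (("guide", guide_sites), ("model", model_sites)):
    bad = first_bad_site(sites)
    if bad is None:
        print(label, "ok")
    else:
        print("{}: non-finite log-density at site {!r}: {}".format(label, *bad))
```

Running the check separately on the guide and the model at least narrows the problem down to one of the two and to a named site.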

To add to some of the suggestions:
1) Revert your changes to get back to a working state of the VAE, then apply one change at a time, putting debugger statements in your guide and model to see exactly where the NaN first occurs.
2) You can inspect the parameters directly in the param store via pyro.get_param_store(), and you can step through and inspect the computed gradients before a step is taken. The Trace_ELBO class also provides a loss() method you can call to inspect the ELBO without taking a step. Make sure those values are what you'd expect.
Unfortunately, as @fritzo alludes to, debugging NaNs is a very ad hoc process that plagues even the most experienced researchers.

@jpchen I am trying to debug a case where a single evaluation of svi.loss gives an error. I have been carefully stepping through with pdb, and just discovered that the problem is probably with my model. A nan warning at pyro/infer/trace_elbo.py(85)_get_traces() would have saved me a bunch of time, and could at least have told me whether the guide or model is wrong.

I recommend checking whether you have softmax or sigmoids in the loop for
categoricals. They can produce NaNs. Also, are you using a Gaussian? Is your std
constrained to be positive?
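
One way a sigmoid can lead to NaN, sketched in plain Python under the assumption that this is the kind of failure meant above: for a large logit the sigmoid rounds to exactly 1.0, so log(1 - p) becomes log(0), and 0 * -inf is NaN. The function names are hypothetical, and `tensor_style_log` only mimics how tensor libraries return -inf for log(0) instead of raising.

```python
import math

def tensor_style_log(p):
    """Mimic a tensor library's log: return -inf at 0 instead of raising."""
    return math.log(p) if p > 0 else float("-inf")

def naive_bernoulli_log_prob(x, logit):
    """Bernoulli log-probability computed through an explicit sigmoid."""
    p = 1.0 / (1.0 + math.exp(-logit))  # rounds to exactly 1.0 for large logits
    return x * tensor_style_log(p) + (1 - x) * tensor_style_log(1 - p)

def stable_bernoulli_log_prob(x, logit):
    """Same quantity computed from the logit directly: x * logit - softplus(logit)."""
    softplus = max(logit, 0.0) + math.log1p(math.exp(-abs(logit)))
    return x * logit - softplus

print(naive_bernoulli_log_prob(1.0, 40.0))   # nan, from 0.0 * -inf
print(stable_bernoulli_log_prob(1.0, 40.0))  # finite, very close to 0
```

Working with logits rather than probabilities avoids the saturated intermediate value.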


I had a really silly error. I had the order of arguments in pyro.observe mixed up. I still think an intelligently worded warning message indicating that my problem was in the model would have saved me a lot of time.

@nbren12 Could you paste a diff of your error and fix? Something like

- pyro.observe("x", dist.categorical, ps, x)
+ pyro.observe("x", dist.categorical, x, ps)

It's a little embarrassing:

-        return pyro.observe("obs", dist.normal, mu_img, sig_img, y)
+        return pyro.observe("obs", dist.normal, y, mu_img, sig_img)

@nbren12 Only embarrassing for the interface designers :wink: We're currently working on error reporting for the distributions library, and your example helps us know where the sharp edges are. We'll try to add a warning for a nonpositive sigma parameter to dist.normal.
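
To illustrate why the swapped arguments matter and what such a check could look like, here is a small plain-Python sketch. It is not Pyro's implementation: `normal_log_pdf` is a hypothetical scalar function, and the values of `y`, `mu_img` and `sig_img` are made up. With the arguments in the wrong order, the observation (here a zero-valued pixel) lands in the sigma slot.

```python
import math

def normal_log_pdf(x, mu, sigma):
    """Log-density of Normal(mu, sigma) at x, with an explicit check on sigma."""
    if not sigma > 0:  # written this way so that a NaN sigma is rejected too
        raise ValueError("normal: sigma must be positive, got {}".format(sigma))
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)

# Hypothetical values: an observed pixel, a predicted mean, a predicted std.
y, mu_img, sig_img = 0.0, 0.3, 0.1

print(normal_log_pdf(y, mu_img, sig_img))  # intended order: finite log-density

try:
    normal_log_pdf(mu_img, sig_img, y)     # swapped order: y is used as sigma
except ValueError as err:
    print("caught:", err)
```

Without the check, a tensor library would silently evaluate log(0) or log of a negative number and the NaN would only surface later in the loss.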

This was also the same issue with #392, as we were discussing. There were two solutions that were discussed - deprecate pyro.observe in favor of pyro.sample, or change the arguments to pyro.observe so that the observation is specified first, followed by the function and its arguments.

I still think we should deprecate pyro.observe!

@nbren12 A few of us on the development team have been bitten by this as well. One way to avoid it (and perhaps more intuitive to read) is to write the distribution as an object, i.e.:

pyro.observe("obs", dist.Normal(mu, sigma), observation)

> change the arguments to pyro.observe so that the observation is specified first

I'm wary of doing this, since it would break code for all our existing users.
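
The object-style call suggested above separates the distribution's parameters from the observation, which is why the order mix-up becomes harder to make. A toy sketch of the idea in plain Python follows; the `Normal` class and `observe` function here are hypothetical stand-ins, not Pyro's classes or signatures.

```python
import math
from dataclasses import dataclass

@dataclass(frozen=True)
class Normal:
    """Toy stand-in for a distribution object (not Pyro's class)."""
    mu: float
    sigma: float

    def log_pdf(self, x):
        z = (x - self.mu) / self.sigma
        return -0.5 * z * z - math.log(self.sigma) - 0.5 * math.log(2 * math.pi)

def observe(name, distribution, observation):
    """Hypothetical observe: the parameters live in `distribution`, so only the data is passed here."""
    return distribution.log_pdf(observation)

# Naming the parameters at construction leaves a single slot for the data.
score = observe("obs", Normal(mu=0.3, sigma=0.1), 0.0)
print(score)
```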

@jpchen I do like that syntax better since it seems more explicit. That said, I feel like this problem is probably a good example of saying "fool me once...". I may just use pyro.sample in the future.
