Versions 0.1.42+: out_buf is referenced before assignment at line 371 of xla.py.
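A minimal way to reproduce (the config.update call is assumed to toggle the same jax_debug_nans flag as passing --jax_debug_nans on the command line):

```python
# Reproduction sketch for jax 0.1.42+ with NaN checking enabled.
from jax.config import config
config.update("jax_debug_nans", True)  # assumed equivalent to --jax_debug_nans

import jax.numpy as jnp
from jax import jit, random

key = random.PRNGKey(0)
random.split(key)                            # random.split is jitted internally, so this alone hits the broken path
print(jit(lambda x: x + 1.0)(jnp.ones(3)))   # any explicit jit call reaches the same post-execution NaN check
```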
You need to pass the flag --jax_debug_debug_nans.
Just kidding. Thanks for catching this! Want to make a PR, or want us to handle it?
Thanks @mattjj. I will make a PR asap.
@mattjj I could not find a quick fix for it.
No worries, it's the kind of thing where the context of how this came to be broken is helpful: in short, we revised the core so that jaxprs always have multiple return values rather than possibly returning tuples, and this code wasn't updated to match.
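Concretely, the affected spot in xla.py's _execute_compiled presumably needs roughly this shape after the multiple-results change (a sketch using the names from that version of the function, not necessarily the exact patch that lands):

```python
# Sketch of the fix: Execute(...).destructure() yields a list of buffers, so the
# NaN check has to loop over out_bufs instead of referencing a single out_buf.
out_bufs = compiled.Execute(input_bufs).destructure()
if FLAGS.jax_debug_nans:
    for out_buf in out_bufs:
        check_nans(xla_call_p, out_buf)
return [handler(out_buf) for handler, out_buf in zip(handlers, out_bufs)]
```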
Even more importantly, you've also found that we don't have a test for this code! We should close this issue only once we have a test in place.
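A regression test along these lines would cover it (an illustrative sketch in the style of the JAX test suite; the class and method names are placeholders, not the test that eventually lands):

```python
from absl.testing import absltest
from jax import jit, test_util as jtu
from jax.config import config
import jax.numpy as jnp


class DebugNaNsTest(jtu.JaxTestCase):

  def testJitWithNaNCheckDoesNotCrash(self):
    # With NaN checking on, a jitted call with finite outputs must run the
    # post-execution check without raising NameError.
    config.update("jax_debug_nans", True)
    try:
      ans = jit(lambda x: x + 1.0)(jnp.ones(3))
      self.assertAllClose(ans, 2.0 * jnp.ones(3), check_dtypes=True)
    finally:
      config.update("jax_debug_nans", False)


if __name__ == "__main__":
  absltest.main()
```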
Hey, @mattjj. Any news on this issue? I noticed that it hasn't been addressed in the latest release. Thanks!
This is likely related: https://github.com/google/jax/pull/1324
I think @romanngg fixed this in #1324, though we should still add a test so we don't regress it again. I'll close this issue for now; please ping it if the problem persists.
@mattjj and @romanngg, the error persists.
~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/random.py in split(key, num)
191 An array with shape (num, 2) and dtype uint32 representing `num` new keys.
192 """
--> 193 return _split(key, num)
194
195 @partial(jit, static_argnums=(1,))
~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/api.py in f_jitted(*args, **kwargs)
145 _check_args(args_flat)
146 flat_fun, out_tree = flatten_fun(f, in_tree)
--> 147 out = xla.xla_call(flat_fun, *args_flat, device_assignment=device_assignment, backend=backend)
148 return tree_unflatten(out_tree(), out)
149
~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
567 if top_trace is None:
568 with new_sublevel():
--> 569 outs = primitive.impl(f, *args, **params)
570 else:
571 tracers = map(top_trace.full_raise, args)
~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, *args, **params)
364 backend = params.get('backend', None)
365 compiled_fun = _xla_callable(fun, device_assignment, backend,
--> 366 *map(abstractify, args))
367 try:
368 return compiled_fun(*args)
~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/linear_util.py in cached_fun(f, *args)
215
216 def cached_fun(f, *args):
--> 217 ans, f_prev = cached_fun_body(f, args)
218 if id(f_prev) != id(f):
219 f.populate_stores(f_prev)
~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/linear_util.py in cached_fun_body(f, args)
212 @fastcache.clru_cache(maxsize=max_size)
213 def cached_fun_body(f, args):
--> 214 return call(f, *args), f
215
216 def cached_fun(f, *args):
~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device_assignment, backend, *abstract_args)
378 pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
379 with core.new_master(pe.JaxprTrace, True) as master:
--> 380 jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
381 assert not env # no subtraces here (though cond might eventually need them)
382 axis_env = AxisEnv(jaxpr_replicas(jaxpr), [], [])
~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
163
164 del gen
--> 165 ans = self.f(*args, **dict(self.params, **kwargs))
166 del args
167 while stack:
~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/random.py in _split(key, num)
196 def _split(key, num):
197 counts = lax.tie_in(key, lax.iota(onp.uint32, num * 2))
--> 198 return lax.reshape(threefry_2x32(key, counts), (num, 2))
199
200
~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/api.py in f_jitted(*args, **kwargs)
145 _check_args(args_flat)
146 flat_fun, out_tree = flatten_fun(f, in_tree)
--> 147 out = xla.xla_call(flat_fun, *args_flat, device_assignment=device_assignment, backend=backend)
148 return tree_unflatten(out_tree(), out)
149
~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
570 else:
571 tracers = map(top_trace.full_raise, args)
--> 572 outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
573 return apply_todos(env_trace_todo(), outs)
574
~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/interpreters/partial_eval.py in process_call(self, call_primitive, f, tracers, params)
111 in_pvs, in_consts = unzip2([t.pval for t in tracers])
112 fun, aux = partial_eval(f, self, in_pvs)
--> 113 out_flat = call_primitive.bind(fun, *in_consts, **params)
114 out_pvs, jaxpr, env = aux()
115 out_pv_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
567 if top_trace is None:
568 with new_sublevel():
--> 569 outs = primitive.impl(f, *args, **params)
570 else:
571 tracers = map(top_trace.full_raise, args)
~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, *args, **params)
366 *map(abstractify, args))
367 try:
--> 368 return compiled_fun(*args)
369 except FloatingPointError:
370 print("Invalid value encountered in the output of a jit function. "
~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/interpreters/xla.py in _execute_compiled(compiled, backend, handlers, *args)
401 input_bufs = [device_put(x, device_num, backend=backend) for x in args]
402 out_bufs = compiled.Execute(input_bufs).destructure()
--> 403 if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_buf)
404 return [handler(out_buf) for handler, out_buf in zip(handlers, out_bufs)]
405
NameError: name 'out_buf' is not defined
Ah, you'll need to check against GitHub master. We haven't updated PyPI yet, though I can do that shortly.
Just updated PyPI to jax 0.1.45 with this fix.
You are right, @mattjj. I've just rebuilt it from source and it's running under the debug_nans flag like a charm. Thanks!