Jax: Bug in jax_debug_nans

Created on 28 Aug 2019 · 11 Comments · Source: google/jax

In versions 0.1.42+, out_buf is referenced before assignment on line 371 of xla.py.

bug

All 11 comments

You need to pass the flag --jax_debug_debug_nans.

Just kidding. Thanks for catching this! Want to make a PR, or want us to handle it?

Thanks @mattjj. I will make a PR asap.

@mattjj I could not find a quick fix for it.

No worries, it's the kind of thing where the context of how this came to be broken is helpful (in short, we revised the core so that jaxprs always have multiple return values rather than possibly returning tuples, and this code wasn't updated properly).

Even more importantly, you've also found that we don't have a test for this code! We should close this issue only once we have a test in place.

Hey, @mattjj. Any news on this issue? I noticed that it hasn't been addressed in the latest release. Thanks!

I think @romanngg fixed this in #1324, though we should add a test so we don't regress it again. I'll close this issue, though! Please ping it if the issue still exists.

@mattjj and @romanngg, the error persists.

~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/random.py in split(key, num)
    191     An array with shape (num, 2) and dtype uint32 representing `num` new keys.
    192   """
--> 193   return _split(key, num)
    194 
    195 @partial(jit, static_argnums=(1,))

~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    145     _check_args(args_flat)
    146     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 147     out = xla.xla_call(flat_fun, *args_flat, device_assignment=device_assignment, backend=backend)
    148     return tree_unflatten(out_tree(), out)
    149 

~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
    567   if top_trace is None:
    568     with new_sublevel():
--> 569       outs = primitive.impl(f, *args, **params)
    570   else:
    571     tracers = map(top_trace.full_raise, args)

~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, *args, **params)
    364   backend = params.get('backend', None)
    365   compiled_fun = _xla_callable(fun, device_assignment, backend,
--> 366                                *map(abstractify, args))
    367   try:
    368     return compiled_fun(*args)

~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/linear_util.py in cached_fun(f, *args)
    215 
    216   def cached_fun(f, *args):
--> 217     ans, f_prev = cached_fun_body(f, args)
    218     if id(f_prev) != id(f):
    219       f.populate_stores(f_prev)

~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/linear_util.py in cached_fun_body(f, args)
    212   @fastcache.clru_cache(maxsize=max_size)
    213   def cached_fun_body(f, args):
--> 214     return call(f, *args), f
    215 
    216   def cached_fun(f, *args):

~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device_assignment, backend, *abstract_args)
    378   pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
    379   with core.new_master(pe.JaxprTrace, True) as master:
--> 380     jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
    381     assert not env  # no subtraces here (though cond might eventually need them)
    382     axis_env = AxisEnv(jaxpr_replicas(jaxpr), [], [])

~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    163 
    164     del gen
--> 165     ans = self.f(*args, **dict(self.params, **kwargs))
    166     del args
    167     while stack:

~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/random.py in _split(key, num)
    196 def _split(key, num):
    197   counts = lax.tie_in(key, lax.iota(onp.uint32, num * 2))
--> 198   return lax.reshape(threefry_2x32(key, counts), (num, 2))
    199 
    200 

~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    145     _check_args(args_flat)
    146     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 147     out = xla.xla_call(flat_fun, *args_flat, device_assignment=device_assignment, backend=backend)
    148     return tree_unflatten(out_tree(), out)
    149 

~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
    570   else:
    571     tracers = map(top_trace.full_raise, args)
--> 572     outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
    573   return apply_todos(env_trace_todo(), outs)
    574 

~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/interpreters/partial_eval.py in process_call(self, call_primitive, f, tracers, params)
    111     in_pvs, in_consts = unzip2([t.pval for t in tracers])
    112     fun, aux = partial_eval(f, self, in_pvs)
--> 113     out_flat = call_primitive.bind(fun, *in_consts, **params)
    114     out_pvs, jaxpr, env = aux()
    115     out_pv_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])

~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
    567   if top_trace is None:
    568     with new_sublevel():
--> 569       outs = primitive.impl(f, *args, **params)
    570   else:
    571     tracers = map(top_trace.full_raise, args)

~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, *args, **params)
    366                                *map(abstractify, args))
    367   try:
--> 368     return compiled_fun(*args)
    369   except FloatingPointError:
    370     print("Invalid value encountered in the output of a jit function. "

~/.pyenv/versions/ML/lib/python3.6/site-packages/jax/interpreters/xla.py in _execute_compiled(compiled, backend, handlers, *args)
    401   input_bufs = [device_put(x, device_num, backend=backend) for x in args]
    402   out_bufs = compiled.Execute(input_bufs).destructure()
--> 403   if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_buf)
    404   return [handler(out_buf) for handler, out_buf in zip(handlers, out_bufs)]
    405 

NameError: name 'out_buf' is not defined
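
For readers following along, here is a minimal sketch in plain Python of the failure pattern in the last frame of the traceback above. It is not the JAX source: check_nans, execute_buggy, and execute_fixed are hypothetical stand-ins, the debug_nans argument stands in for FLAGS.jax_debug_nans, and the "fixed" variant is only one plausible way to handle several output buffers, not necessarily what the actual patch does.

```python
import math


def check_nans(name, buf):
    """Simplified stand-in: raise FloatingPointError if buf contains a NaN."""
    if any(math.isnan(v) for v in buf):
        raise FloatingPointError(f"invalid value (nan) encountered in {name}")


def execute_buggy(out_bufs, debug_nans):
    # Mirrors the failing line: out_buf is never assigned in this scope,
    # so the check raises NameError, but only when the flag is enabled.
    if debug_nans:
        check_nans("xla_call", out_buf)  # noqa: F821 (intentional bug)
    return [list(buf) for buf in out_bufs]


def execute_fixed(out_bufs, debug_nans):
    # Assumed fix: check each output buffer, since there can be several.
    if debug_nans:
        for out_buf in out_bufs:
            check_nans("xla_call", out_buf)
    return [list(buf) for buf in out_bufs]


if __name__ == "__main__":
    outputs = [[1.0, 2.0], [3.0, 4.0]]
    # With the flag off, the broken branch is never reached.
    execute_buggy(outputs, debug_nans=False)
    try:
        execute_buggy(outputs, debug_nans=True)
    except NameError as err:
        print("buggy path:", err)
    # The fixed variant runs cleanly on finite outputs with the flag on.
    execute_fixed(outputs, debug_nans=True)
```

Because the broken line sits behind the flag, the error stays hidden unless a test runs with the flag turned on, which is why a regression test for this path needs to enable it explicitly.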

Ah, you'll need to check against github master. We haven't updated pypi yet, though I can do that shortly.

Just updated pypi to jax 0.1.45 with this fix.

You are right, @mattjj. I've just rebuilt it from source and it's running with the debug_nans flag like a charm. Thanks!
