Jax: jit compiled closures in fori loop

Created on 8 Jan 2019 · 11 Comments · Source: google/jax

Hi,

This might be a misunderstanding on my part, but the following code raises an error:

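# Imports assumed; the original post omitted them.
import jax.numpy as np
from jax import jit, lax

def iterate(f, init_row, steps):
  def body_fun(step, result):
    return f(result)
  return lax.fori_loop(0, steps, body_fun, init_row)

# The lambda closes over `a`, which jit turns into a traced value.
def test_get(a,b): return iterate(lambda x: x[a], b, 10)
test_get_jit = jit(test_get)
test_get_jit(np.array([0,1,2,3]), np.array([0,1,2,3]))
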
/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py in Constant(self, py_val)
    319       return _constant_handlers[py_type](self, py_val)
    320     else:
--> 321       raise TypeError("No constant handler for type: {}".format(py_type))
    322 
    323 

TypeError: No constant handler for type: <class 'jax.interpreters.partial_eval.JaxprTracer'>
enhancement

All 11 comments

I'm glad you're pushing the boundaries here!

The control flow constructs aren't really ready for use yet. However, if you want to use them now, you have to give them functions that don't close over any traced values. (Soon control flow constructs will handle closed-over traced values correctly, as the machinery exists in JAX's core.py, but for now they don't.)

In this case, f closes over the value referred to as a in the body of test_get, which is traced by the use of jit. You'd need to manually thread that value into the initial loop carry of fori_loop for it to work.

Got it! Thanks a lot.

Reopening this since I now understand it was meant to be tracked for future improvement.

We're on it!

I am facing a similar problem. What's the recommended workaround for the above example?

Instead of closing over values, you need to add them as formal parameters to body_fun and cond_fun and thread them through the loop carry. In other words, cond_fun and body_fun shouldn't close over anything that could be traced.
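For a loop that also has a cond_fun, the same workaround might be sketched as below. It assumes the public `lax.while_loop(cond_fun, body_fun, init_val)` entry point; `gather_repeatedly` and the layout of the carry tuple are illustrative choices, not anything prescribed by the thread. Both functions are defined at top level and read everything they need from the carry.

```python
import jax.numpy as jnp
from jax import jit, lax


def cond_fun(carry):
    # Everything the condition needs is read from the carry.
    step, limit, _, _ = carry
    return step < limit


def body_fun(carry):
    step, limit, a, result = carry
    # `limit` and `a` are passed through unchanged so later iterations can use them.
    return step + 1, limit, a, result[a]


@jit
def gather_repeatedly(a, b, limit):
    # Illustrative name: applies x -> x[a] to b, `limit` times.
    init = (jnp.int32(0), limit, a, b)
    _, _, _, final = lax.while_loop(cond_fun, body_fun, init)
    return final


out = gather_repeatedly(jnp.array([1, 2, 3, 0]), jnp.array([0, 1, 2, 3]), 3)
```

Values that never change, such as `limit` and `a` here, still have to be returned from body_fun on every iteration because the carry must keep the same structure.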

Thanks for the fast reply. That's unfortunate, because in my case the closure is nested quite deep in a sequence of nested functions, and the outermost one is body_fun. Together with #282, these two issues make the code quite ugly (if one wants to use jit), because I need lots of top-level functions with lots of arguments that cannot be simplified using closures or partials (except maybe by manually unpacking the partial objects into function and data). Nevertheless, JAX is great. Looking forward to contributing more.

Just merged #334 which should allow closed-over tracers in the cond/body functions of _while_loop and fori_loop. It's only lightly tested so let us know if it works!

Thanks for fixing this!

I would like to try this out. I'm working in Colab. Is there an easy way to install the master branch?

Related: how often are new releases pushed to pypi? (is there a nightly??)

We've just been updating pypi manually. The last update was yesterday, I think. Running pip install --upgrade jax will pick it up. (The jaxlib package is just for updating XLA, which we do less frequently.)

I'm not sure if you can point pip to a git repo, but that would be useful.

Yes, pip install git+https://github.com/google/jax should work for that.
