Hello
I'm trying to use `lax.cond` together with `grad`. According to this table, this should be possible. Is that right?
I've reproduced the error with this minimal example:
def f_1(x):
    """Return *x* raised to the second power."""
    squared = pow(x, 2)
    return squared
def f_2(x):
    """Return *x* raised to the third power."""
    cubed = pow(x, 3)
    return cubed
def f(x):
    """Choose between ``f_1`` and ``f_2`` with ``lax.cond`` based on ``x > 0``.

    When ``x > 0`` the result is ``f_1(x)``; otherwise it is ``f_2(x)``.
    Both branches receive the same operand *x*.
    """
    # Per-branch-operand form: cond(pred, true_operand, true_fun,
    #                              false_operand, false_fun).
    # NOTE(review): newer JAX releases document
    # cond(pred, true_fun, false_fun, *operands); the older form is kept here
    # to match the JAX version (0.1.59) discussed in this thread.
    return lax.cond(x > 0, x, f_1, x, f_2)
# Differentiate f at the positive point 3.0, so the x > 0 branch (f_1) is taken.
y = grad(f)(3.0)
print(y)
This raises:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-95-51fb82770856> in <module>
10 return y
11
---> 12 y = grad(f)(3.)
13 print(y)
~/one-leg/.venv/lib/python3.5/site-packages/jax/api.py in grad_f(*args, **kwargs)
354 @wraps(fun, docstr=docstr, argnums=argnums)
355 def grad_f(*args, **kwargs):
--> 356 _, g = value_and_grad_f(*args, **kwargs)
357 return g
358
~/one-leg/.venv/lib/python3.5/site-packages/jax/api.py in value_and_grad_f(*args, **kwargs)
409 f_partial, dyn_args = _argnums_partial(f, argnums, args)
410 if not has_aux:
--> 411 ans, vjp_py = vjp(f_partial, *dyn_args)
412 else:
413 ans, vjp_py, aux = vjp(f_partial, *dyn_args, has_aux=True)
~/one-leg/.venv/lib/python3.5/site-packages/jax/api.py in vjp(fun, *primals, **kwargs)
1287 if not has_aux:
1288 flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1289 out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
1290 out_tree = out_tree()
1291 else:
~/one-leg/.venv/lib/python3.5/site-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
107 def vjp(traceable, primals, has_aux=False):
108 if not has_aux:
--> 109 out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
110 else:
111 out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
~/one-leg/.venv/lib/python3.5/site-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
99 pval_primals, pval_tangents = tree_unflatten(out_tree(), out_pvals)
100 aval_primals, const_primals = unzip2(pval_primals)
--> 101 assert all(aval_primal is None for aval_primal in aval_primals)
102 if not has_aux:
103 return const_primals, pval_tangents, jaxpr, consts
AssertionError:
JAX's support for `cond` is very recent. When I run your example with version 0.1.59 installed (released last week), it works. Please try upgrading and let us know if that works!
I just double-checked this as well! Running `pip install --upgrade jax` should fix it. Please reopen this issue, or open a new one, if you run into any problems!
Great, upgrading did indeed solve it.
As always, thanks for the quick replies!
Most helpful comment
JAX's support for `cond` is very recent. When I run your example with version 0.1.59 installed (released last week), it works. Please try upgrading and let us know if that works!