JAX: Using grad on a function that uses lax.cond raises an unclear AssertionError

Created on 18 Feb 2020 · 3 comments · Source: google/jax

Hello

I'm trying to use lax.cond together with grad. According to this table, this should be possible, shouldn't it?

I've replicated the error with this simple code:

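from jax import grad, lax

def f_1(x):
    return x ** 2

def f_2(x):
    return x ** 3


def f(x):
    # Five-argument form of lax.cond used at the time:
    # lax.cond(pred, true_operand, true_fun, false_operand, false_fun)
    y = lax.cond(x > 0, x, f_1, x, f_2)
    return y

y = grad(f)(3.)
print(y)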

This raises:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-95-51fb82770856> in <module>
     10     return y
     11 
---> 12 y = grad(f)(3.)
     13 print(y)

~/one-leg/.venv/lib/python3.5/site-packages/jax/api.py in grad_f(*args, **kwargs)
    354   @wraps(fun, docstr=docstr, argnums=argnums)
    355   def grad_f(*args, **kwargs):
--> 356     _, g = value_and_grad_f(*args, **kwargs)
    357     return g
    358 

~/one-leg/.venv/lib/python3.5/site-packages/jax/api.py in value_and_grad_f(*args, **kwargs)
    409     f_partial, dyn_args = _argnums_partial(f, argnums, args)
    410     if not has_aux:
--> 411       ans, vjp_py = vjp(f_partial, *dyn_args)
    412     else:
    413       ans, vjp_py, aux = vjp(f_partial, *dyn_args, has_aux=True)

~/one-leg/.venv/lib/python3.5/site-packages/jax/api.py in vjp(fun, *primals, **kwargs)
   1287   if not has_aux:
   1288     flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1289     out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
   1290     out_tree = out_tree()
   1291   else:

~/one-leg/.venv/lib/python3.5/site-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    107 def vjp(traceable, primals, has_aux=False):
    108   if not has_aux:
--> 109     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    110   else:
    111     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

~/one-leg/.venv/lib/python3.5/site-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
     99   pval_primals, pval_tangents = tree_unflatten(out_tree(), out_pvals)
    100   aval_primals, const_primals = unzip2(pval_primals)
--> 101   assert all(aval_primal is None for aval_primal in aval_primals)
    102   if not has_aux:
    103     return const_primals, pval_tangents, jaxpr, consts

AssertionError: 

JAX's support for cond is very recent. When I run your example with version 0.1.59 installed (released last week), it works. Please try upgrading and let us know whether that fixes it!
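For reference, a minimal sketch of the same reproduction on a newer JAX install. It assumes a release recent enough to accept the lax.cond(pred, true_fun, false_fun, operand) signature (the five-argument form in the question is the older one). It checks that grad differentiates whichever branch the predicate selects, and that the gradient function can also be jit-compiled.

```python
import jax
from jax import grad, lax


def f_1(x):
    return x ** 2


def f_2(x):
    return x ** 3


def f(x):
    # Newer signature (assumed available in the installed JAX):
    # lax.cond(pred, true_fun, false_fun, *operands).
    # Both branches receive the same operand x.
    return lax.cond(x > 0, f_1, f_2, x)


df = grad(f)
print(df(3.0))   # x > 0, so this is d/dx x**2 at 3.0, i.e. 2 * 3.0
print(df(-2.0))  # x <= 0, so this is d/dx x**3 at -2.0, i.e. 3 * (-2.0)**2

# The gradient function can also be compiled with jit.
print(jax.jit(df)(3.0))
```

Both branches must return values of the same shape and dtype, because lax.cond traces both of them regardless of which one runs.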

I just double-checked this too! Running pip install --upgrade jax should fix it. Please reopen this issue, or open a new one, if you run into any problems!

Great, upgrading solved it.
As always, thanks for the quick replies!
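To confirm that an upgrade actually took effect, here is a small sketch that compares the installed version against 0.1.59, the release the maintainer reports the example working with. version_tuple and MIN_VERSION are hypothetical names introduced here for illustration, and the parser only looks at the leading major.minor.patch numbers of jax.__version__.

```python
import re

import jax


def version_tuple(version):
    """Hypothetical helper: parse the leading 'major.minor.patch' numbers."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


# Hypothetical threshold: the release reported to run the example correctly.
MIN_VERSION = (0, 1, 59)

installed = version_tuple(jax.__version__)
print("jax", jax.__version__)
if installed < MIN_VERSION:
    print("Installed jax is older than 0.1.59; try: pip install --upgrade jax")
```

Tuples compare element by element, so (0, 4, 31) >= (0, 1, 59) holds even though 31 < 59.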
