Jax: TypeError: <class 'jax.ad_util.Zero'> is not a valid Jax type when I combine custom_jvp and lax.scan

Created on 9 Apr 2020 · 4 comments · Source: google/jax

Hi all

I have a difficult bug which I don't understand. It started appearing when I defined my custom gradients with custom_jvp (in forward mode); before that, I did it in reverse mode with custom_gradient.
I'm up-to-date on the master branch, as of this writing.

TypeError: <class 'jax.ad_util.Zero'> is not a valid Jax type

I've tried to make a simple code snippet to reproduce it, but it's still rather complex because the bug, as far as I can tell, only appears when I combine different things.
I've tried lots of combinations, and so far I've seen the bug appear only when:
1) I define the gradient with custom_jvp
2) The custom gradient depends on dA (A does not depend on theta; I think this is related)
3) I use lax.scan (the bug disappears when I use a plain Python loop; see the loop variant sketched after the snippet below)
4) theta is used as below, directly in step (this is not breaking the 'pure function' requirement, right?)

import jax
import jax.numpy as np

@jax.custom_jvp
def f(A, b):
    return A @ b

def f_jvp(primals, tangents):
    A, b = primals
    dA, db = tangents
    z = f(A, b)
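    # Note: mathematically this should be dz = A @ db + dA @ b. The incorrect
    # form below is kept on purpose; both versions trigger the same error (see the discussion below).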
    dz = dA @ db
    return z, dz

f.defjvp(f_jvp)


def experiment(theta):
    def step(q, _):
        z = f(np.eye(3), np.ones(3) * theta)
        q += z[0]
        return q, q

    q = 0.
    q, _ = jax.lax.scan(step, q, None, 4)
    return q


experiment_grad = jax.grad(experiment)

g = experiment_grad(1.)
print(g)
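
For reference, here is a minimal sketch of the Python-loop variant mentioned in point 3 (reusing f and the imports from the snippet above). As far as I can tell, this version does not trigger the error:

def experiment_loop(theta):
    # Same computation as experiment, but with the scan unrolled into a plain Python loop.
    q = 0.
    for _ in range(4):
        z = f(np.eye(3), np.ones(3) * theta)
        q += z[0]
    return q

print(jax.grad(experiment_loop)(1.))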

Thanks for the help!

Rembert

bug

Most helpful comment

Thanks for the report, and for testing out this new feature even before we released it! Yes, this looks like a bug to me.

By the way, the line dz = dA @ db in your JVP rule tripped me up. That's a math bug, even if it should be fine to write it in a JVP rule, because the output gradient must be a linear function of the input gradients. But switching that to the correct dz = A @ db + dA @ b gives the same error, so that isn't the issue here.

All 4 comments

This is the stack trace:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~/jax/jax/core.py in concrete_aval(x)
    676   try:
--> 677     return pytype_aval_mappings[type(x)](x)
    678   except KeyError as err:

KeyError: <class 'jax.ad_util.Zero'>

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
<ipython-input-2-b58b959b770c> in <module>
     29 experiment_grad = jax.grad(experiment)
     30 
---> 31 g = experiment_grad(1.)
     32 print(g)

~/jax/jax/api.py in grad_f(*args, **kwargs)
    370   @wraps(fun, docstr=docstr, argnums=argnums)
    371   def grad_f(*args, **kwargs):
--> 372     _, g = value_and_grad_f(*args, **kwargs)
    373     return g
    374 

~/jax/jax/api.py in value_and_grad_f(*args, **kwargs)
    426     f_partial, dyn_args = argnums_partial(f, argnums, args)
    427     if not has_aux:
--> 428       ans, vjp_py = _vjp(f_partial, *dyn_args)
    429     else:
    430       ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)

~/jax/jax/api.py in _vjp(fun, *primals, **kwargs)
   1386   if not has_aux:
   1387     flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1388     out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
   1389     out_tree = out_tree()
   1390   else:

~/jax/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    104 def vjp(traceable, primals, has_aux=False):
    105   if not has_aux:
--> 106     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    107   else:
    108     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

~/jax/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
     93   _, in_tree = tree_flatten(((primals, primals), {}))
     94   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
---> 95   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
     96   pval_primals, pval_tangents = tree_unflatten(out_tree(), out_pvals)
     97   aval_primals, const_primals = unzip2(pval_primals)

~/jax/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom)
    372   with new_master(trace_type, bottom=bottom) as master:
    373     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 374     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    375     assert not env
    376     del master

~/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148     gen = None
    149 
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:

<ipython-input-2-b58b959b770c> in experiment(theta)
     23 
     24     q = 0.
---> 25     q, _ = jax.lax.scan(step, q, None, 4)
     26     return q
     27 

~/jax/jax/lax/lax_control_flow.py in scan(f, init, xs, length)
    846   x_dtypes = [x.dtype for x in xs_flat]
    847   x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes))
--> 848   jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals)
    849   out_tree_children = out_tree.children()
    850   if len(out_tree_children) != 2:

~/jax/jax/lax/lax_control_flow.py in _initial_style_jaxpr(fun, in_tree, in_avals)
     60   with core.initial_style_staging():
     61     jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
---> 62       wrapped_fun, in_pvals, instantiate=True, stage_out=False)
     63   out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
     64   const_avals = tuple(raise_to_shaped(core.get_aval(c)) for c in consts)

~/jax/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom)
    372   with new_master(trace_type, bottom=bottom) as master:
    373     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 374     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    375     assert not env
    376     del master

~/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148     gen = None
    149 
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:

<ipython-input-2-b58b959b770c> in step(q, _)
     18 def experiment(theta):
     19     def step(q, _):
---> 20         z = f(np.eye(3), np.ones(3) * theta)
     21         q += z[0]
     22         return q, q

~/jax/jax/custom_derivatives.py in __call__(self, *args, **kwargs)
    211     flat_jvp, out_tree2 = _flatten_jvp(jvp, in_tree)
    212     if core.trace_state.initial_style:
--> 213       out_flat = custom_jvp_call_jaxpr(flat_fun, flat_jvp, *args_flat)
    214       out_tree = out_tree1()
    215     else:

~/jax/jax/custom_derivatives.py in custom_jvp_call_jaxpr(fun, jvp, *args)
    280   jvp_jaxpr_thunk = _memoize(lambda: _initial_style_jaxpr(jvp, in_avals * 2))
    281   return custom_jvp_call_jaxpr_p.bind(*args, fun_jaxpr=fun_jaxpr,
--> 282                                       jvp_jaxpr_thunk=jvp_jaxpr_thunk)
    283 
    284 def _custom_jvp_call_jaxpr_impl(*args, fun_jaxpr, **_):

~/jax/jax/core.py in bind(self, *args, **kwargs)
    200 
    201     tracers = map(top_trace.full_raise, args)
--> 202     out_tracer = top_trace.process_primitive(self, tracers, kwargs)
    203     if self.multiple_results:
    204       return map(full_lower, out_tracer)

~/jax/jax/interpreters/ad.py in process_primitive(self, primitive, tracers, params)
    300           "Forward-mode differentiation rule for '{}' not implemented"
    301           .format(primitive)) from err
--> 302     primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
    303     if primitive.multiple_results:
    304       return [JVPTracer(self, x, t) for x, t in zip(primal_out, tangent_out)]

~/jax/jax/custom_derivatives.py in _custom_jvp_call_jaxpr_jvp(primals, tangents, fun_jaxpr, jvp_jaxpr_thunk)
    295 def _custom_jvp_call_jaxpr_jvp(primals, tangents, *, fun_jaxpr, jvp_jaxpr_thunk):
    296   jvp_jaxpr = jvp_jaxpr_thunk()
--> 297   outs = core.jaxpr_as_fun(jvp_jaxpr)(*(primals + tangents))
    298   return split_list(outs, [len(outs) // 2])
    299 ad.primitive_jvps[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_jvp

~/jax/jax/core.py in jaxpr_as_fun(typed_jaxpr, *args)
    108 @curry
    109 def jaxpr_as_fun(typed_jaxpr: TypedJaxpr, *args):
--> 110   return eval_jaxpr(typed_jaxpr.jaxpr, typed_jaxpr.literals, *args)
    111 
    112 

~/jax/jax/core.py in eval_jaxpr(jaxpr, consts, *args)
    267     else:
    268       subfuns = []
--> 269     ans = eqn.primitive.bind(*(subfuns + in_vals), **params)
    270     if eqn.primitive.multiple_results:
    271       map(write, eqn.outvars, ans)

~/jax/jax/core.py in bind(self, *args, **kwargs)
    200 
    201     tracers = map(top_trace.full_raise, args)
--> 202     out_tracer = top_trace.process_primitive(self, tracers, kwargs)
    203     if self.multiple_results:
    204       return map(full_lower, out_tracer)

~/jax/jax/interpreters/partial_eval.py in process_primitive(self, primitive, tracers, params)
     97       return custom_partial_eval_rules[primitive](self, *tracers, **params)
     98     else:
---> 99       return self.default_process_primitive(primitive, tracers, params)
    100 
    101   def default_process_primitive(self, primitive, tracers, params):

~/jax/jax/interpreters/partial_eval.py in default_process_primitive(self, primitive, tracers, params)
    103     if all(pv is None for pv in pvs):
    104       return primitive.bind(*consts, **params)
--> 105     tracers = map(self.instantiate_const, tracers)
    106     avals = [t.aval for t in tracers]
    107     out_aval = primitive.abstract_eval(*avals, **params)

~/jax/jax/util.py in safe_map(f, *args)
     32   for arg in args[1:]:
     33     assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
---> 34   return list(map(f, *args))
     35 
     36 def unzip2(xys):

~/jax/jax/interpreters/partial_eval.py in instantiate_const(self, tracer)
     79         return self.new_instantiated_literal(const)
     80       else:
---> 81         return self.new_instantiated_const(const)
     82     else:
     83       raise TypeError(pv)

~/jax/jax/interpreters/partial_eval.py in new_instantiated_const(self, val)
     65 
     66   def new_instantiated_const(self, val):
---> 67     return JaxprTracer(self, PartialVal((get_aval(val), unit)), ConstVar(val))
     68 
     69   def new_arg(self, pval):

~/jax/jax/core.py in get_aval(x)
    684     return x.aval
    685   else:
--> 686     return concrete_aval(x)
    687 
    688 

~/jax/jax/core.py in concrete_aval(x)
    677     return pytype_aval_mappings[type(x)](x)
    678   except KeyError as err:
--> 679     raise TypeError("{} is not a valid Jax type".format(type(x))) from err
    680 
    681 

TypeError: <class 'jax.ad_util.Zero'> is not a valid Jax type

Thanks for the report, and for testing out this new feature even before we released it! Yes, this looks like a bug to me.

By the way, the line dz = dA @ db in your JVP rule tripped me up. That's a math bug, even if it should be fine to write it in a JVP rule, because the output gradient must be a linear function of the input gradients. But switching that to the correct dz = A @ db + dA @ b gives the same error, so that isn't the issue here.
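
For concreteness, here is the repro's JVP rule with that corrected tangent line (a minimal sketch; only the dz line differs from the original snippet):

import jax

@jax.custom_jvp
def f(A, b):
    return A @ b

def f_jvp(primals, tangents):
    A, b = primals
    dA, db = tangents
    z = f(A, b)
    # The output tangent must be a linear function of the input tangents (dA, db).
    dz = A @ db + dA @ b
    return z, dz

f.defjvp(f_jvp)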

#2673 has the fix!

Thanks guys!

> Thanks for the report, and for testing out this new feature even before we released it! Yes, this looks like a bug to me.

Well I was inspired by @mattjj's great notebook on this, so I went ahead and defined my implicit gradients in forward mode, for which the math was a lot easier!

> By the way, the line dz = dA @ db in your JVP rule tripped me up. That's a math bug, even if it should be fine to write it in a JVP rule, because the output gradient must be a linear function of the input gradients. But switching that to the correct dz = A @ db + dA @ b gives the same error, so that isn't the issue here.

Yes, I know the math is wrong; I was just trying all sorts of different JVPs to see their effect on the bug.
Really, one of the fun things in JAX debugging is that you never know whether it's a math bug or a software bug, right? ;)
