Hi all
I've run into a difficult bug that I don't understand. It started appearing when I switched to defining my custom gradients with custom_jvp (forward mode); before that, I defined them in reverse mode with custom_gradient.
I'm up to date with the master branch as of this writing. The error I get is:
TypeError: <class 'jax.ad_util.Zero'> is not a valid Jax type
I've tried to boil it down to a simple code snippet, but the reproduction is still fairly involved because, as far as I can tell, the bug only appears when several things are combined.
I've tried lots of combinations, and so far I've seen the bug appear only when:
1) I define the gradient with custom_jvp
2) The custom JVP depends on dA (A does not depend on theta; I think this has something to do with it)
3) I use lax.scan (the bug disappears when I use a python loop)
4) theta is used directly in step, as below (this doesn't break the 'pure function' requirement, right?)
import jax
import jax.numpy as np

@jax.custom_jvp
def f(A, b):
    return A @ b

def f_jvp(primals, tangents):
    A, b = primals
    dA, db = tangents
    z = f(A, b)
    dz = dA @ db
    return z, dz

f.defjvp(f_jvp)

def experiment(theta):
    def step(q, _):
        z = f(np.eye(3), np.ones(3) * theta)
        q += z[0]
        return q, q

    q = 0.
    q, _ = jax.lax.scan(step, q, None, 4)
    return q

experiment_grad = jax.grad(experiment)
g = experiment_grad(1.)
print(g)
Thanks for the help!
Rembert
This is the stack trace:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
~/jax/jax/core.py in concrete_aval(x)
676 try:
--> 677 return pytype_aval_mappings[type(x)](x)
678 except KeyError as err:
KeyError: <class 'jax.ad_util.Zero'>
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
<ipython-input-2-b58b959b770c> in <module>
29 experiment_grad = jax.grad(experiment)
30
---> 31 g = experiment_grad(1.)
32 print(g)
~/jax/jax/api.py in grad_f(*args, **kwargs)
370 @wraps(fun, docstr=docstr, argnums=argnums)
371 def grad_f(*args, **kwargs):
--> 372 _, g = value_and_grad_f(*args, **kwargs)
373 return g
374
~/jax/jax/api.py in value_and_grad_f(*args, **kwargs)
426 f_partial, dyn_args = argnums_partial(f, argnums, args)
427 if not has_aux:
--> 428 ans, vjp_py = _vjp(f_partial, *dyn_args)
429 else:
430 ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
~/jax/jax/api.py in _vjp(fun, *primals, **kwargs)
1386 if not has_aux:
1387 flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1388 out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
1389 out_tree = out_tree()
1390 else:
~/jax/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
104 def vjp(traceable, primals, has_aux=False):
105 if not has_aux:
--> 106 out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
107 else:
108 out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
~/jax/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
93 _, in_tree = tree_flatten(((primals, primals), {}))
94 jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
---> 95 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
96 pval_primals, pval_tangents = tree_unflatten(out_tree(), out_pvals)
97 aval_primals, const_primals = unzip2(pval_primals)
~/jax/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom)
372 with new_master(trace_type, bottom=bottom) as master:
373 fun = trace_to_subjaxpr(fun, master, instantiate)
--> 374 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
375 assert not env
376 del master
~/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
148 gen = None
149
--> 150 ans = self.f(*args, **dict(self.params, **kwargs))
151 del args
152 while stack:
<ipython-input-2-b58b959b770c> in experiment(theta)
23
24 q = 0.
---> 25 q, _ = jax.lax.scan(step, q, None, 4)
26 return q
27
~/jax/jax/lax/lax_control_flow.py in scan(f, init, xs, length)
846 x_dtypes = [x.dtype for x in xs_flat]
847 x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes))
--> 848 jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals)
849 out_tree_children = out_tree.children()
850 if len(out_tree_children) != 2:
~/jax/jax/lax/lax_control_flow.py in _initial_style_jaxpr(fun, in_tree, in_avals)
60 with core.initial_style_staging():
61 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
---> 62 wrapped_fun, in_pvals, instantiate=True, stage_out=False)
63 out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
64 const_avals = tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
~/jax/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom)
372 with new_master(trace_type, bottom=bottom) as master:
373 fun = trace_to_subjaxpr(fun, master, instantiate)
--> 374 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
375 assert not env
376 del master
~/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
148 gen = None
149
--> 150 ans = self.f(*args, **dict(self.params, **kwargs))
151 del args
152 while stack:
<ipython-input-2-b58b959b770c> in step(q, _)
18 def experiment(theta):
19 def step(q, _):
---> 20 z = f(np.eye(3), np.ones(3) * theta)
21 q += z[0]
22 return q, q
~/jax/jax/custom_derivatives.py in __call__(self, *args, **kwargs)
211 flat_jvp, out_tree2 = _flatten_jvp(jvp, in_tree)
212 if core.trace_state.initial_style:
--> 213 out_flat = custom_jvp_call_jaxpr(flat_fun, flat_jvp, *args_flat)
214 out_tree = out_tree1()
215 else:
~/jax/jax/custom_derivatives.py in custom_jvp_call_jaxpr(fun, jvp, *args)
280 jvp_jaxpr_thunk = _memoize(lambda: _initial_style_jaxpr(jvp, in_avals * 2))
281 return custom_jvp_call_jaxpr_p.bind(*args, fun_jaxpr=fun_jaxpr,
--> 282 jvp_jaxpr_thunk=jvp_jaxpr_thunk)
283
284 def _custom_jvp_call_jaxpr_impl(*args, fun_jaxpr, **_):
~/jax/jax/core.py in bind(self, *args, **kwargs)
200
201 tracers = map(top_trace.full_raise, args)
--> 202 out_tracer = top_trace.process_primitive(self, tracers, kwargs)
203 if self.multiple_results:
204 return map(full_lower, out_tracer)
~/jax/jax/interpreters/ad.py in process_primitive(self, primitive, tracers, params)
300 "Forward-mode differentiation rule for '{}' not implemented"
301 .format(primitive)) from err
--> 302 primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
303 if primitive.multiple_results:
304 return [JVPTracer(self, x, t) for x, t in zip(primal_out, tangent_out)]
~/jax/jax/custom_derivatives.py in _custom_jvp_call_jaxpr_jvp(primals, tangents, fun_jaxpr, jvp_jaxpr_thunk)
295 def _custom_jvp_call_jaxpr_jvp(primals, tangents, *, fun_jaxpr, jvp_jaxpr_thunk):
296 jvp_jaxpr = jvp_jaxpr_thunk()
--> 297 outs = core.jaxpr_as_fun(jvp_jaxpr)(*(primals + tangents))
298 return split_list(outs, [len(outs) // 2])
299 ad.primitive_jvps[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_jvp
~/jax/jax/core.py in jaxpr_as_fun(typed_jaxpr, *args)
108 @curry
109 def jaxpr_as_fun(typed_jaxpr: TypedJaxpr, *args):
--> 110 return eval_jaxpr(typed_jaxpr.jaxpr, typed_jaxpr.literals, *args)
111
112
~/jax/jax/core.py in eval_jaxpr(jaxpr, consts, *args)
267 else:
268 subfuns = []
--> 269 ans = eqn.primitive.bind(*(subfuns + in_vals), **params)
270 if eqn.primitive.multiple_results:
271 map(write, eqn.outvars, ans)
~/jax/jax/core.py in bind(self, *args, **kwargs)
200
201 tracers = map(top_trace.full_raise, args)
--> 202 out_tracer = top_trace.process_primitive(self, tracers, kwargs)
203 if self.multiple_results:
204 return map(full_lower, out_tracer)
~/jax/jax/interpreters/partial_eval.py in process_primitive(self, primitive, tracers, params)
97 return custom_partial_eval_rules[primitive](self, *tracers, **params)
98 else:
---> 99 return self.default_process_primitive(primitive, tracers, params)
100
101 def default_process_primitive(self, primitive, tracers, params):
~/jax/jax/interpreters/partial_eval.py in default_process_primitive(self, primitive, tracers, params)
103 if all(pv is None for pv in pvs):
104 return primitive.bind(*consts, **params)
--> 105 tracers = map(self.instantiate_const, tracers)
106 avals = [t.aval for t in tracers]
107 out_aval = primitive.abstract_eval(*avals, **params)
~/jax/jax/util.py in safe_map(f, *args)
32 for arg in args[1:]:
33 assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
---> 34 return list(map(f, *args))
35
36 def unzip2(xys):
~/jax/jax/interpreters/partial_eval.py in instantiate_const(self, tracer)
79 return self.new_instantiated_literal(const)
80 else:
---> 81 return self.new_instantiated_const(const)
82 else:
83 raise TypeError(pv)
~/jax/jax/interpreters/partial_eval.py in new_instantiated_const(self, val)
65
66 def new_instantiated_const(self, val):
---> 67 return JaxprTracer(self, PartialVal((get_aval(val), unit)), ConstVar(val))
68
69 def new_arg(self, pval):
~/jax/jax/core.py in get_aval(x)
684 return x.aval
685 else:
--> 686 return concrete_aval(x)
687
688
~/jax/jax/core.py in concrete_aval(x)
677 return pytype_aval_mappings[type(x)](x)
678 except KeyError as err:
--> 679 raise TypeError("{} is not a valid Jax type".format(type(x))) from err
680
681
TypeError: <class 'jax.ad_util.Zero'> is not a valid Jax type
Thanks for the report, and for testing out this new feature even before we released it! Yes, this looks like a bug to me.
By the way, the line dz = dA @ db in your JVP rule tripped me up. That's a math bug, because the output tangent must be a linear function of the input tangents, even if writing it in a JVP rule should still work mechanically. But switching it to the correct dz = A @ db + dA @ b gives the same error, so that isn't the issue here.
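For reference, the rule I tried was something like the sketch below: just your f_jvp with that one line switched to the product rule. It still hits the same TypeError, so treat it only as the math fix, not a workaround.

def f_jvp(primals, tangents):
    A, b = primals
    dA, db = tangents
    z = f(A, b)
    # Product rule: d(A @ b) = dA @ b + A @ db, which is linear in the tangents.
    dz = dA @ b + A @ db
    return z, dz

f.defjvp(f_jvp)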
Thanks guys!
> Thanks for the report, and for testing out this new feature even before we released it! Yes, this looks like a bug to me.
Well I was inspired by @mattjj's great notebook on this, so I went ahead and defined my implicit gradients in forward mode, for which the math was a lot easier!
> By the way, the line dz = dA @ db in your JVP rule tripped me up. That's a math bug, because the output tangent must be a linear function of the input tangents, even if writing it in a JVP rule should still work mechanically. But switching it to the correct dz = A @ db + dA @ b gives the same error, so that isn't the issue here.
Yes, I know the math is wrong; I was just trying all sorts of different JVPs to see their effect on the bug.
Really, one of the fun things about debugging in JAX is that you never know whether it's a math bug or a software bug, right? ;)
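For what it's worth, this is the variant from point 3 above that avoids the error on my machine: the same experiment, with the lax.scan replaced by a plain Python loop and everything else unchanged.

def experiment_loop(theta):
    # Same computation as experiment, but with the scan unrolled by hand.
    q = 0.
    for _ in range(4):
        z = f(np.eye(3), np.ones(3) * theta)
        q += z[0]
    return q

print(jax.grad(experiment_loop)(1.))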