I get the following exception when trying to take the gradient of a function that uses cumprod, e.g.
jax.jacrev(lambda x: jnp.cumprod(x, axis=-1))(np.random.randn(5))
File "jax/api.py", line 623, in batched_fun
out_flat = batching.batch(jaxtree_fun, in_flat, in_axes_, out_axes)
File "jax/interpreters/batching.py", line 45, in batch
return batch_transform(fun, sz, in_dims, out_dim_dst).call_wrapped(in_vals)
File "jax/linear_util.py", line 161, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "jax/api.py", line 501, in jacfun
y, pullback = vjp(f_partial, *dyn_args)
File "jax/api.py", line 1019, in vjp
out_primal, out_vjp = ad.vjp(jaxtree_fun, primals_flat)
File "jax/interpreters/ad.py", line 105, in vjp
out_primal, pval, jaxpr, consts = linearize(traceable, *primals)
File "jax/interpreters/ad.py", line 94, in linearize
jaxpr, out_pval, consts = pe.trace_to_jaxpr(jvpfun, in_pvals)
File "jax/interpreters/partial_eval.py", line 400, in trace_to_jaxpr
jaxpr, (out_pval, consts, env) = fun.call_wrapped(pvals)
File "jax/linear_util.py", line 161, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
... my code, ultimately calling np.cumprod ...
File "jax/numpy/lax_numpy.py", line 1206, in cumulative_reduction
return _cumulative_reduction(a, axis, dtype)
File "jax/api.py", line 151, in f_jitted
device_assignment=device_assignment)
File "jax/core.py", line 675, in call_bind
ans = full_lower(top_trace.process_call(primitive, f, tracers, params))
File "jax/interpreters/ad.py", line 260, in process_call
result = call_primitive.bind(f_jvp, pack(primals), nonzero_tangents, **params)
File "jax/core.py", line 675, in call_bind
ans = full_lower(top_trace.process_call(primitive, f, tracers, params))
File "jax/interpreters/partial_eval.py", line 116, in process_call
out_pv_const, consts = call_primitive.bind(fun, *in_consts, **params)
File "jax/core.py", line 675, in call_bind
ans = full_lower(top_trace.process_call(primitive, f, tracers, params))
File "jax/interpreters/batching.py", line 135, in process_call
val_out = call_primitive.bind(f, *vals, **params)
File "jax/core.py", line 672, in call_bind
ans = primitive.impl(f, *args, **params)
File "jax/interpreters/xla.py", line 667, in _xla_call_impl
*map(abstractify, args))
File "jax/linear_util.py", line 213, in cached_fun
ans, f_prev = cached_fun_body(f, args)
File "jax/linear_util.py", line 210, in cached_fun_body
return call(f, *args), f
File "jax/interpreters/xla.py", line 679, in _xla_callable
jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
File "jax/linear_util.py", line 161, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "jax/numpy/lax_numpy.py", line 1201, in _cumulative_reduction
a, window_dims, strides, xla_client.PaddingType.VALID)
File "jax/lax/lax.py", line 939, in _reduce_window_prod
window_strides=tuple(window_strides), padding=padding)
File "jax/core.py", line 148, in bind
out_tracer = top_trace.process_primitive(self, tracers, kwargs)
File "jax/interpreters/ad.py", line 251, in process_primitive
.format(primitive))
NotImplementedError: Forward-mode differentiation rule for 'reduce_window' not implemented
It's easy enough to mimic TF's behavior here:
import jax.numpy as jnp
from jax import custom_transforms, defjvp

def cumprod_jvp(g, ans, x):
    return jnp.cumsum(g / x) * ans

cumprod = custom_transforms(jnp.cumprod)
defjvp(cumprod, cumprod_jvp)
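As a quick check, the rule can be exercised directly (a minimal sketch, assuming the old custom_transforms/defjvp API used above, which predates jax.custom_jvp):

import numpy as np
import jax

x = np.random.randn(5)

# Forward mode now uses cumprod_jvp instead of differentiating reduce_window.
y, y_dot = jax.jvp(cumprod, (x,), (np.ones_like(x),))

# The original jacrev repro should also go through: the JVP rule is linear in
# the tangent g, so it can be transposed for reverse mode.
jac = jax.jacrev(cumprod)(x)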
This workaround has two downsides:
a) It doesn't support the extra keyword arguments (dtype and axis); we'd need to extend custom_transforms a bit to allow that.
b) It doesn't work correctly if any entry of x is 0, because of the division g / x. TF has the same bug (mishandling zeros). PyTorch does not, because it falls back to a much more expensive quadratic algorithm whenever any entry is 0 (a sketch of such a fallback is below).
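To make the zero case concrete, a quadratic fallback in the spirit of PyTorch's could look roughly like this. It's only a sketch for the 1-D case, and cumprod_jvp_dense is a hypothetical name rather than anything from JAX or PyTorch; ans is kept only to match the (g, ans, x) rule signature above.

import jax.numpy as jnp

def cumprod_jvp_dense(g, ans, x):
    # O(n^2), but well-defined even when entries of x are exactly 0:
    # d cumprod(x)[i] / d x[j] = prod_{k <= i, k != j} x[k]  for j <= i, else 0.
    del ans  # unused; kept for the same calling convention as cumprod_jvp
    n = x.shape[-1]
    # Row j: cumprod of x with x[j] replaced by 1, so that
    # partials[j, i] = prod_{k <= i, k != j} x[k].
    x_masked = jnp.where(jnp.eye(n, dtype=bool), 1.0, x[None, :])
    partials = jnp.cumprod(x_masked, axis=-1)
    # Outputs with i < j don't depend on x[j].
    jac = jnp.where(jnp.arange(n)[None, :] >= jnp.arange(n)[:, None], partials, 0.0)
    # Tangent: dy[i] = sum_j g[j] * d cumprod(x)[i] / d x[j].
    return g @ jac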
(There's also a clearly correct solution: rewriting cumprod with lax.scan, though it will most likely be slower than the current reduce_window-based implementation; see the sketch below.)
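For completeness, a minimal sketch of that lax.scan version (cumprod_scan is a hypothetical helper; scan has JVP and transpose rules, so it differentiates without any custom rule, at the cost of serializing the reduction):

import jax.numpy as jnp
from jax import lax

def cumprod_scan(x, axis=-1):
    # Step along the reduction axis, carrying the running product.
    x = jnp.moveaxis(x, axis, 0)

    def step(carry, xi):
        carry = carry * xi
        return carry, carry

    _, out = lax.scan(step, jnp.ones_like(x[0]), x)
    return jnp.moveaxis(out, 0, axis)

# The original repro works unchanged:
# jax.jacrev(lambda x: cumprod_scan(x, axis=-1))(np.random.randn(5))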
This is also something we are interested in using in pyhf to compute Hessians of likelihoods. cc @kratsg @matthewfeickert
This should be fixed at head. Hope that helps!