Jax: Add autodiff support for `reduce_window`

Created on 20 Aug 2019 · 4 Comments · Source: google/jax

I get the following exception trying to take the grad of a cumprod-using function, e.g.

jax.jacrev(lambda x: jnp.cumprod(x, axis=-1))(np.random.randn(5))
File "jax/api.py", line 623, in batched_fun
    out_flat = batching.batch(jaxtree_fun, in_flat, in_axes_, out_axes)
  File "jax/interpreters/batching.py", line 45, in batch
    return batch_transform(fun, sz, in_dims, out_dim_dst).call_wrapped(in_vals)
  File "jax/linear_util.py", line 161, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "jax/api.py", line 501, in jacfun
    y, pullback = vjp(f_partial, *dyn_args)
  File "jax/api.py", line 1019, in vjp
    out_primal, out_vjp = ad.vjp(jaxtree_fun, primals_flat)
  File "jax/interpreters/ad.py", line 105, in vjp
    out_primal, pval, jaxpr, consts = linearize(traceable, *primals)
  File "jax/interpreters/ad.py", line 94, in linearize
    jaxpr, out_pval, consts = pe.trace_to_jaxpr(jvpfun, in_pvals)
  File "jax/interpreters/partial_eval.py", line 400, in trace_to_jaxpr
    jaxpr, (out_pval, consts, env) = fun.call_wrapped(pvals)
  File "jax/linear_util.py", line 161, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  ... my code, ultimately calling np.cumprod ...

  File "jax/numpy/lax_numpy.py", line 1206, in cumulative_reduction
    return _cumulative_reduction(a, axis, dtype)
  File "jax/api.py", line 151, in f_jitted
    device_assignment=device_assignment)
  File "jax/core.py", line 675, in call_bind
    ans = full_lower(top_trace.process_call(primitive, f, tracers, params))
  File "jax/interpreters/ad.py", line 260, in process_call
    result = call_primitive.bind(f_jvp, pack(primals), nonzero_tangents, **params)
  File "jax/core.py", line 675, in call_bind
    ans = full_lower(top_trace.process_call(primitive, f, tracers, params))
  File "jax/interpreters/partial_eval.py", line 116, in process_call
    out_pv_const, consts = call_primitive.bind(fun, *in_consts, **params)
  File "jax/core.py", line 675, in call_bind
    ans = full_lower(top_trace.process_call(primitive, f, tracers, params))
  File "jax/interpreters/batching.py", line 135, in process_call
    val_out = call_primitive.bind(f, *vals, **params)
  File "jax/core.py", line 672, in call_bind
    ans = primitive.impl(f, *args, **params)
  File "jax/interpreters/xla.py", line 667, in _xla_call_impl
    *map(abstractify, args))
  File "jax/linear_util.py", line 213, in cached_fun
    ans, f_prev = cached_fun_body(f, args)
  File "jax/linear_util.py", line 210, in cached_fun_body
    return call(f, *args), f
  File "jax/interpreters/xla.py", line 679, in _xla_callable
    jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
  File "jax/linear_util.py", line 161, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "jax/numpy/lax_numpy.py", line 1201, in _cumulative_reduction
    a, window_dims, strides, xla_client.PaddingType.VALID)
  File "jax/lax/lax.py", line 939, in _reduce_window_prod
    window_strides=tuple(window_strides), padding=padding)
  File "jax/core.py", line 148, in bind
    out_tracer = top_trace.process_primitive(self, tracers, kwargs)
  File "jax/interpreters/ad.py", line 251, in process_primitive
    .format(primitive))
NotImplementedError: Forward-mode differentiation rule for 'reduce_window' not implemented
Labels: enhancement

All 4 comments

It's easy enough to mimic TF's behavior here:

import jax.numpy as jnp
from jax import custom_transforms, defjvp

def cumprod_jvp(g, ans, x):
  # tangent of cumprod: ans_i * sum_{j<=i} g_j / x_j
  return jnp.cumsum(g / x) * ans
cumprod = custom_transforms(jnp.cumprod)
defjvp(cumprod, cumprod_jvp)
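
A quick way to exercise the rule (a sketch of my own, assuming the custom_transforms API of the time routes reverse mode through this JVP; the input is kept nonzero to sidestep caveat b below):

import numpy as onp
from jax import grad

x = onp.abs(onp.random.randn(5)) + 0.1  # strictly nonzero entries
print(grad(lambda v: cumprod(v).sum())(x))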

This has two downsides:
a) it doesn't support the extra keyword arguments (dtype and axis). We need to extend custom_transforms a bit to allow that.
b) it doesn't work correctly if any entry in x is 0, since the rule divides by x. Note that TF has the same bug (mishandling 0s). PyTorch does not have this bug because it falls back to a much more expensive quadratic algorithm if any entry is 0.

(There's also a clearly correct solution by rewriting cumprod using lax.scan, but it will most likely be slower than the current implementation.)
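
For reference, here is a minimal sketch of that lax.scan approach (my own illustration, not from the thread; `cumprod_scan` is a hypothetical name and it only reduces along the leading axis):

import jax.numpy as jnp
from jax import lax

def cumprod_scan(x):
  # Carry the running product down the leading axis; scan is differentiable,
  # so gradients stay correct even when entries of x are zero.
  def step(carry, xi):
    prod = carry * xi
    return prod, prod
  _, out = lax.scan(step, jnp.ones_like(x[0]), x)
  return out

It avoids the division entirely, so zeros are handled correctly, at the cost of a sequential scan.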

this is also something that we are interested in using in pyhf to compute hessians of likelihoods cc @kratsg @matthewfeickert

This should be fixed at head. Hope that helps!
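
For anyone landing here later: assuming a build from head, the original repro should now run without the NotImplementedError, e.g.

import jax, jax.numpy as jnp, numpy as onp
jax.jacrev(lambda x: jnp.cumprod(x, axis=-1))(onp.random.randn(5))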
