Jax: Assertion caused by scan, vmap, and custom derivatives

Created on 9 Oct 2020 · 8 Comments · Source: google/jax

It seems like there's a bad interaction between vmap and something in custom derivatives:

  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/custom_derivatives.py", line 464, in __call__
    out_flat = custom_vjp_call_jaxpr(flat_fun, flat_fwd, flat_bwd,
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/custom_derivatives.py", line 530, in custom_vjp_call_jaxpr
    return custom_vjp_call_jaxpr_p.bind(*args, fun_jaxpr=fun_jaxpr,
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/core.py", line 266, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/batching.py", line 159, in process_primitive
    return map(partial(BatchTracer, self), val_out, dim_out)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/util.py", line 34, in safe_map
    assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
AssertionError: length mismatch: [13, 100]

I've spent about a day trying to make a minimal working example. I'll keep trying, but I'm sharing this in case it jumps out as obvious to someone.

bug

All 8 comments

Any chance you're defining your custom_vjp function inside another function or a class? My understanding is that custom_vjp/custom_jvp should always be written as "top level" functions (i.e., with no indentation).

Otherwise, this is definitely a bug -- the newer custom derivatives interface was designed specifically around handling vmap of custom derivatives properly.
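For reference, here is a minimal sketch of the pattern the interface is meant to support: a top-level custom_vjp function differentiated under vmap. The function scaled_sin and its rules are hypothetical, chosen only for illustration, and are not taken from the report.

```python
import jax
import jax.numpy as jnp


# Hypothetical top-level function with a hand-written VJP rule.
@jax.custom_vjp
def scaled_sin(x, y):
    return jnp.sin(x) * y


def scaled_sin_fwd(x, y):
    # Return the primal output plus the residuals needed by the backward pass.
    return jnp.sin(x) * y, (jnp.cos(x), jnp.sin(x), y)


def scaled_sin_bwd(residuals, out_bar):
    cos_x, sin_x, y = residuals
    # One cotangent per primal input: (d/dx, d/dy).
    return cos_x * y * out_bar, sin_x * out_bar


scaled_sin.defvjp(scaled_sin_fwd, scaled_sin_bwd)

xs = jnp.linspace(0.0, 1.0, 5)
y = jnp.asarray(2.0)

# Batch over xs only; y is shared across the batch.
batched = jax.vmap(scaled_sin, in_axes=(0, None))
grad_xs = jax.grad(lambda v: batched(v, y).sum())(xs)
```

If the custom rule is batched correctly, grad_xs should match cos(xs) * y elementwise.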

ping @mattjj

@shoyer I am defining a custom_vjp, but it's on a top-level function.

Here's a minimal working example:

from jax import custom_vjp
from jax import numpy as jnp
from jax import vjp, vmap
from jax.lax import scan


@custom_vjp
def g(x, y):
    return None

def g_fwd(x, y):
    return None, y

def g_bwd(residuals, z_bar):
    assert False

g.defvjp(g_fwd, g_bwd)

def f(xs, y):
    v_g = vmap(g, in_axes=(0, None), out_axes=None)
    v_g(xs, y)

def scan_body(xs, _):
    y = jnp.zeros(1)
    _, vjp_f = vjp(f, xs, y)
    vjp_f(None)
    return xs, None

scan(scan_body, jnp.ones(5), None, 100)

Produces:

  File "a.py", line 21, in f
    v_g(xs, y)
  File "...jax/traceback_util.py", line 137, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "...jax/api.py", line 1230, in batched_fun
    out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
  File "...jax/interpreters/batching.py", line 36, in batch
    return batched_fun.call_wrapped(*in_vals)
  File "...jax/linear_util.py", line 151, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "...jax/custom_derivatives.py", line 464, in __call__
    out_flat = custom_vjp_call_jaxpr(flat_fun, flat_fwd, flat_bwd,
  File "...jaxpr
    return custom_vjp_call_jaxpr_p.bind(*args, fun_jaxpr=fun_jaxpr,
  File "...jax/core.py", line 266, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "...jax/interpreters/batching.py", line 159, in process_primitive
    return map(partial(BatchTracer, self), val_out, dim_out)
  File "...jax/util.py", line 34, in safe_map
    assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
AssertionError: length mismatch: [0, 1]
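
For contrast, here is a sketch of a variant of the reproduction above in which g returns an array and g_bwd returns real cotangents, so the same scan + vjp + vmap + custom_vjp combination can be checked numerically. The arithmetic (x * y), the scalar y, and the scan length of 3 are assumptions made for illustration. This is not the code that triggers the assertion.

```python
import jax.numpy as jnp
from jax import custom_vjp, vjp, vmap
from jax.lax import scan


@custom_vjp
def g(x, y):
    return x * y


def g_fwd(x, y):
    # Save both inputs as residuals for the backward pass.
    return x * y, (x, y)


def g_bwd(residuals, z_bar):
    x, y = residuals
    return z_bar * y, z_bar * x


g.defvjp(g_fwd, g_bwd)


def f(xs, y):
    # Batch over xs; y is shared, so its cotangent is summed over the batch.
    return vmap(g, in_axes=(0, None))(xs, y)


def scan_body(xs, _):
    y = jnp.asarray(2.0)
    out, vjp_f = vjp(f, xs, y)
    xs_bar, y_bar = vjp_f(jnp.ones_like(out))
    return xs, (xs_bar, y_bar)


_, (xs_bars, y_bars) = scan(scan_body, jnp.ones(5), None, length=3)
```

On a JAX version where this combination works, each row of xs_bars should equal y and each entry of y_bars should equal the sum of xs.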

Thanks, @NeilGirdhar, not only for this report and clear minimal reproduction, but also for your extreme patience in putting up with custom_vjp bugs while I delayed actually fixing things once and for all (my excuse was landing #3370).

But we will soon awake from that long nightmare! I have a branch, just pushed to #4008 on Friday, that will close all open issues about custom_jvp and custom_vjp closure difficulties, and I think it will eliminate the possibility of these bugs in the future.

When I tried this repro on #4008 I got this error instead:

ValueError: vmap out_axes specification must be a tree prefix of the corresponding value, got specification () for value tree PyTreeDef(None, []).

Is that an error we expect?
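
For readers unfamiliar with the "tree prefix" wording in that message, here is a small sketch of how out_axes is matched against a function's output structure. The function scale_and_echo is hypothetical, and the sketch does not reproduce the error above. It only shows two specifications that are valid prefixes of a tuple output.

```python
import jax
import jax.numpy as jnp


# Hypothetical function returning a 2-tuple: one batched value, one unbatched.
def scale_and_echo(x, y):
    return x * y, y


xs = jnp.arange(5.0)
y = jnp.asarray(2.0)

# out_axes mirrors the output tuple: batch the first output along axis 0,
# leave the second (which does not depend on xs) unbatched.
scaled, echoed = jax.vmap(
    scale_and_echo, in_axes=(0, None), out_axes=(0, None)
)(xs, y)

# A single integer is a prefix of any output tree: it applies to every leaf,
# so the unbatched output is broadcast along axis 0.
scaled_b, echoed_b = jax.vmap(
    scale_and_echo, in_axes=(0, None), out_axes=0
)(xs, y)
```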

@shoyer by the way, #4008 removes the top-level-ness requirement for custom_jvp/vjp. See the tests!
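
As a rough illustration of what a non-top-level definition looks like, here is a sketch of a custom_vjp function defined inside a factory function. The names make_clipped_identity and limit are hypothetical and are not taken from the tests in #4008. limit is a plain Python float here, so the sketch does not exercise closing over traced values.

```python
import jax
import jax.numpy as jnp


def make_clipped_identity(limit):
    # custom_vjp defined inside another function, closing over `limit`.
    @jax.custom_vjp
    def clipped_identity(x):
        return x

    def fwd(x):
        # No residuals are needed for the backward pass.
        return x, None

    def bwd(_, x_bar):
        # Clip the incoming cotangent to [-limit, limit].
        return (jnp.clip(x_bar, -limit, limit),)

    clipped_identity.defvjp(fwd, bwd)
    return clipped_identity


clip = make_clipped_identity(0.5)

# The unclipped gradient of 3 * x would be 3; the custom rule caps it at 0.5.
grads = jax.vmap(jax.grad(lambda x: 3.0 * clip(x)))(jnp.arange(4.0))
```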

@mattjj No problem at all. It's my pleasure to produce MWEs, even though it's time-consuming, especially since the JAX team has been exceptionally fast at responding to issues. Thank you very much for that.

It looks like you've fixed the bug! I've updated my MWE in case you want to add it or something like it to your testing. I guess you can add "Fixes #4521" to the description of #4008 so that this will be automatically closed?

Thanks again for addressing this so fast!

@NeilGirdhar Yeah perfect, I'll add your MWE as a test, and I'll add this issue to the "fixes" list.

Closed by #4008
