(Thanks to @John-Jumper for pointing this out.)
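```python
import jax
import numpy as np

@jax.custom_transforms  # old API, superseded by jax.custom_jvp / jax.custom_vjp
def f(x):
    return 2. * x
# Custom VJP that (deliberately) reports a derivative of 3 instead of 2.
jax.defvjp_all(f, lambda x: (2. * x, lambda _: (3.,)))
print('grad', jax.grad(f)(1.))
print('vmap-of-grad', jax.vmap(jax.grad(f))(np.ones(4)))
print('grad-of-sum-vmap', jax.grad(lambda x: jax.vmap(f)(x).sum())(np.ones(4)))
```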
Output:

```
grad 3.0
vmap-of-grad [3. 3. 3. 3.]
grad-of-sum-vmap [2. 2. 2. 2.]
```
When you write a `custom_transforms` primitive, you're basically saying "here is an implementation of this primitive that's valid for everything except the JVP interpreter; for that, use this other implementation." When you `vmap` that primitive, it `vmap`s the first implementation. When you then run the JVP interpreter on the result (e.g. by calling `grad`), the primitive isn't there any more to be overridden, so JVP just differentiates what it's given: in the example above, the plain `2. * x`, which is why the last line prints 2 instead of 3.
This is arguably expected (even correct) behavior when `custom_transforms` overrides preserve semantics, for example when they only improve numerics or performance. But we're seeing `custom_transforms` used in ways that go further than that, such as giving a JVP rule for a function JAX can't differentiate.
One fix for this would involve essentially coercing `jvp(vmap(f))` to `vmap(jvp(f))` (i.e., making sure the overridden interpreter always ends up as the innermost trace). But in general our transformations aren't commutative, so even that fix only goes so far.
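To make the two orderings concrete: for a function with no custom rules they compute the same thing; what differs is which trace sees `f` itself. A minimal sketch, where `f` is a hypothetical elementwise stand-in rather than the function from the report:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical elementwise function with no custom rules.
    return jnp.sin(x) * x

xs = jnp.arange(4.)
ts = jnp.ones(4)

# jvp(vmap(f)): the JVP trace wraps the already-batched function.
out_a, tan_a = jax.jvp(jax.vmap(f), (xs,), (ts,))

# vmap(jvp(f)): JVP is the innermost transformation and sees f directly,
# so an override attached to f would be the one that gets applied.
out_b, tan_b = jax.vmap(lambda x, t: jax.jvp(f, (x,), (t,)))(xs, ts)

print(jnp.allclose(out_a, out_b), jnp.allclose(tan_a, tan_b))
```

Note that the tangents have to be batched alongside the primals in the second form.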
For those running into this problem: the simplest solution is to manually move the overridden transformation inside any other (so if your code has `grad(vmap(f))`, try moving the `vmap` outside the `grad`). A more general, but more complex, user-level fix is to add a `custom_transforms` overload to define `vmap(f)` in terms of `f` (perhaps by making `f` rank-polymorphic).
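For a scalar-to-scalar `f`, the first workaround is a direct rewrite: the gradient of `vmap(f)(xs).sum()` is exactly the per-example derivative that `vmap(grad(f))` computes, so the `vmap` can move outside. A minimal sketch with a hypothetical `f` (a plain function here, standing in for the one with the custom rule):

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical scalar -> scalar function standing in for the custom one.
    return jnp.sin(x)

xs = jnp.linspace(0., 1., 4)

# Original structure: grad(vmap(f)), with a sum to get a scalar output.
inside = jax.grad(lambda x: jax.vmap(f)(x).sum())(xs)

# Workaround: vmap outside, so grad (and any JVP override) sees f itself.
outside = jax.vmap(jax.grad(f))(xs)

print(jnp.allclose(inside, outside))  # both are cos(xs)
```

With a custom rule attached to `f`, only the second form is guaranteed to route the derivative through that rule under the old `custom_transforms` behavior.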
Whew, #2026 fixed this!
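With #2026, `jax.custom_transforms` was deprecated in favor of `jax.custom_jvp` and `jax.custom_vjp`. A sketch of the original repro ported to `jax.custom_vjp`, assuming a JAX version that provides it; the backward rule scales the incoming cotangent by 3 instead of returning a constant, which matches the original rule whenever the cotangent is 1, as in these calls:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return 2. * x

def f_fwd(x):
    # Primal output plus residuals (none needed here).
    return 2. * x, None

def f_bwd(_, g):
    # Deliberately "wrong" rule, as in the report: derivative 3 instead of 2.
    return (3. * g,)

f.defvjp(f_fwd, f_bwd)

grad_val = jax.grad(f)(1.)
vmap_of_grad = jax.vmap(jax.grad(f))(jnp.ones(4))
grad_of_sum_vmap = jax.grad(lambda x: jax.vmap(f)(x).sum())(jnp.ones(4))

print('grad', grad_val)
print('vmap-of-grad', vmap_of_grad)
print('grad-of-sum-vmap', grad_of_sum_vmap)  # now also 3s: the rule survives vmap
```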