Jax: Bad interaction between batch trace and custom jvp trace

Created on 5 Aug 2020 · 7 comments · Source: google/jax

Repro case:

!pip install -U -q jax jaxlib

import jax, jax.numpy as jnp, functools

@functools.partial(jax.custom_jvp, nondiff_argnums=(0, 2))
def sample(shape, param, seed):
  return jax.random.gamma(key=seed, shape=shape, a=param)

@sample.defjvp
def sample_jvp(shape, seed, primals, tangents):
  param, = primals
  dparam, = tangents
  dparam = jnp.broadcast_to(dparam, shape)
  samples = sample(shape, param, seed)
  return samples, samples * dparam  # dummy jvp for proof of concept

jax.vmap(lambda seed: sample((2,3), 1., seed))(
    jax.random.split(jax.random.PRNGKey(1), 10))

=>

/usr/local/lib/python3.6/dist-packages/jax/custom_derivatives.py in bind(self, fun, jvp, *args)
    285     _, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
    286     if env_trace_todo:
--> 287       raise core.UnexpectedTracerError
    288     return map(core.full_lower, outs)
    289 

UnexpectedTracerError: 


Note: the same thing happens under jit of lambda seed: ..., i.e. the same exception is raised by

jax.jit(lambda seed: sample((2,3), 1., seed))(jax.random.PRNGKey(1))

Thanks for flagging this, and for the great repros. I think this is a duplicate of a known bug, which I should look up and link here. The basic issue is that nondiff_argnums behaves like lexical closure over tracers, but that check disallows lexical closure over tracers!
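A minimal sketch of a common way around this in plain JAX, assuming a reasonably recent JAX release (this is my own illustration, not the change made in #4008). My understanding is that the JAX documentation recommends using nondiff_argnums only for non-array values such as shapes or Python callables. So here only the static shape stays in nondiff_argnums, and seed is passed as an ordinary argument that vmap and jit can trace like any other array. The key is integer-typed, so the tangent JAX hands the rule for it carries no information, and the rule ignores it. The dummy JVP is carried over from the repro.

```python
import functools

import jax
import jax.numpy as jnp


# Only `shape` (a hashable Python tuple) is non-differentiable/static here;
# `seed` is an ordinary array argument, so transformations can trace it.
@functools.partial(jax.custom_jvp, nondiff_argnums=(0,))
def sample(shape, param, seed):
    return jax.random.gamma(seed, param, shape=shape)


@sample.defjvp
def sample_jvp(shape, primals, tangents):
    param, seed = primals
    dparam, _ = tangents  # the integer key's tangent is a zero; discard it
    samples = sample(shape, param, seed)
    # Dummy JVP for proof of concept, as in the original repro.
    return samples, samples * jnp.broadcast_to(dparam, shape)


keys = jax.random.split(jax.random.PRNGKey(1), 10)
batched = jax.vmap(lambda seed: sample((2, 3), 1., seed))(keys)
jitted = jax.jit(lambda seed: sample((2, 3), 1., seed))(jax.random.PRNGKey(1))
g = jax.grad(lambda p: sample((2, 3), p, jax.random.PRNGKey(1)).sum())(1.)
```

With the key passed as a regular argument, both the vmap and the jit calls from this thread trace it normally, and differentiating with respect to param still goes through the custom rule.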

@NeilGirdhar has a nice solution, and I believe a PR that I _still_ haven't gotten to.

@brianwa84 how blocking is this for your work? I'd like to get to it this week, but I'm just trying to prioritize stuff. Let me know :)

@brianwa84 I might have a work-around for your issue. If you pip install tjax, then this should work:

import jax, jax.numpy as jnp, functools
from tjax import custom_jvp

@functools.partial(custom_jvp, nondiff_argnums=(0, 2))
def sample(shape, param, seed):
  return jax.random.gamma(key=seed, shape=shape, a=param)

@sample.defjvp
def sample_jvp(shape, seed, primals, tangents):
  param, = primals
  dparam, = tangents
  dparam = jnp.broadcast_to(dparam, shape)
  samples = sample(shape, param, seed)
  return samples, samples * dparam  # dummy jvp for proof of concept

jax.vmap(lambda seed: sample((2,3), 1., seed))(
    jax.random.split(jax.random.PRNGKey(1), 10))

This is a work-around that provides separate nondiff_argnums (arguments for which you don't want tangents) and static_argnums (arguments that must be passed statically, like your own custom non-pytree classes). I wasn't able to make a patch because I don't know the JAX internals well enough. (Ideally, no tracers would be created for the tangents of nondiff arguments, but the workaround simply throws those out.)
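For values that really are static, as in the static_argnums case above, plain jax.custom_jvp's nondiff_argnums is fine even under vmap, because no tracer is ever passed through it. The sketch below is adapted from the callable-argument example in the JAX custom-derivatives tutorial; the names app and app_jvp are illustrative, and the rule returning 2 * x_dot is a deliberately fake derivative, used only to show that the custom rule is the one being applied.

```python
import functools

import jax
import jax.numpy as jnp


# `f` is a plain Python callable: not an array, never traced, no tangent.
@functools.partial(jax.custom_jvp, nondiff_argnums=(0,))
def app(f, x):
    return f(x)


@app.defjvp
def app_jvp(f, primals, tangents):
    x, = primals
    x_dot, = tangents
    # Deliberately fake derivative so it is obvious the custom rule is used.
    return app(f, x), 2. * x_dot


xs = jnp.arange(3.)
ys = jax.vmap(lambda x: app(jnp.sin, x))(xs)          # batches x only
g = jax.grad(app, argnums=1)(lambda x: x ** 3, 3.)    # uses the fake rule
```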

I'm not blocked; I worked around it. I just wanted to report it before I forgot.
:-)


This was fixed by #4008!
