Repro case:
!pip install -U -q jax jaxlib
import jax, jax.numpy as jnp, functools
@functools.partial(jax.custom_jvp, nondiff_argnums=(0, 2))
def sample(shape, param, seed):
    return jax.random.gamma(key=seed, shape=shape, a=param)

@sample.defjvp
def sample_jvp(shape, seed, primals, tangents):
    param, = primals
    dparam, = tangents
    dparam = jnp.broadcast_to(dparam, shape)
    samples = sample(shape, param, seed)
    return samples, samples * dparam  # dummy jvp for proof of concept

jax.vmap(lambda seed: sample((2,3), 1., seed))(
    jax.random.split(jax.random.PRNGKey(1), 10))
=>
/usr/local/lib/python3.6/dist-packages/jax/custom_derivatives.py in bind(self, fun, jvp, *args)
285 _, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
286 if env_trace_todo:
--> 287 raise core.UnexpectedTracerError
288 return map(core.full_lower, outs)
289
UnexpectedTracerError:
Note: the same exception is raised when the lambda is wrapped in jit instead of vmap, i.e. for
jax.jit(lambda seed: sample((2,3), 1., seed))(jax.random.PRNGKey(1))
Thanks for flagging this, and for the great repros. I think this is a duplicate of a known bug, which I should look up and link here. The basic issue is that nondiff_argnums works essentially like lexical closure over tracers, but the check that raises UnexpectedTracerError disallows lexical closure over tracers!
@NeilGirdhar has a nice solution, and I believe a PR that I _still_ haven't gotten to.
@brianwa84 how blocking is this for your work? I'd like to get to it this week, but I'm just trying to prioritize stuff. Let me know :)
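For readers hitting the same error, here is a minimal sketch of one possible way around the tracer-in-nondiff_argnums problem described above. It is not the fix from the linked PR and not the tjax workaround; it assumes it is acceptable to keep only the hashable shape in nondiff_argnums and to pass the seed as an ordinary argument whose tangent the rule simply ignores. The JVP itself is the same dummy rule as in the repro.

```python
import functools

import jax
import jax.numpy as jnp


# Assumption: only the static, hashable `shape` is marked non-differentiable.
# The seed is passed as a regular argument, so it may be a tracer under
# vmap or jit without being closed over by the custom_jvp machinery.
@functools.partial(jax.custom_jvp, nondiff_argnums=(0,))
def sample(shape, param, seed):
    return jax.random.gamma(key=seed, shape=shape, a=param)


@sample.defjvp
def sample_jvp(shape, primals, tangents):
    param, seed = primals
    dparam, _ = tangents  # the tangent of the integer-valued key is ignored
    dparam = jnp.broadcast_to(dparam, shape)
    samples = sample(shape, param, seed)
    return samples, samples * dparam  # same dummy jvp as in the repro


keys = jax.random.split(jax.random.PRNGKey(1), 10)
batched = jax.vmap(lambda seed: sample((2, 3), 1., seed))(keys)
```

Here batched has one (2, 3) sample per key. The cost of this design is that the rule receives a tangent slot for the seed that it has to discard explicitly.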
@brianwa84 I might have a work-around for your issue. If you pip install tjax, then this should work:
import jax, jax.numpy as jnp, functools
from tjax import custom_jvp
@functools.partial(custom_jvp, nondiff_argnums=(0, 2))
def sample(shape, param, seed):
    return jax.random.gamma(key=seed, shape=shape, a=param)

@sample.defjvp
def sample_jvp(shape, seed, primals, tangents):
    param, = primals
    dparam, = tangents
    dparam = jnp.broadcast_to(dparam, shape)
    samples = sample(shape, param, seed)
    return samples, samples * dparam  # dummy jvp for proof of concept

jax.vmap(lambda seed: sample((2,3), 1., seed))(
    jax.random.split(jax.random.PRNGKey(1), 10))
This is a work-around that provides a separate nondiff_argnums (arguments for which you don't want tangents) and static_argnums (arguments that must be passed statically, like your own custom non-pytree classes). I wasn't able to make a patch because I don't know the JAX internals well enough. (Ideally, no tracers would be created for the tangents of nondiff arguments, but the work-around simply throws those out.)
I'm not blocked, I worked around it. Just wanted to report before I forgot.
:-)
This was fixed by #4008 !
https://github.com/google/jax/issues/2912 ?