@duvenaud sent me a repro. This might be the same as #2912.
@duvenaud You could try my shim over `custom_vjp` that separates the concepts of `nondiff_argnums` and `static_argnums`.
@NeilGirdhar thanks for pointing that out; I meant to ask you about that. Is the idea that "static" args are things like functions (i.e. not JAX value types), while nondiff args are JAX values that we don't want to differentiate with respect to (such as RNG keys)?
@mattjj exactly!
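The following is a minimal sketch of that distinction using plain `jax.custom_vjp`, not the shim mentioned above. It assumes a JAX version whose `bwd` rule may return `None` to mean a zero cotangent. The Python callable `f` is not a JAX value, so it goes in `nondiff_argnums`. The RNG key is a JAX value we don't differentiate, so it stays an ordinary argument and the backward rule reports no cotangent for it. The name `noisy_apply` is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `f` is a Python callable (a "static"-style arg), so it is listed in
# nondiff_argnums. `key` is a JAX array we never differentiate, so it is
# passed as a normal argument instead of through nondiff_argnums.
@partial(jax.custom_vjp, nondiff_argnums=(0,))
def noisy_apply(f, x, key):
    noise = jax.random.normal(key, x.shape)
    return f(x) + noise


def noisy_apply_fwd(f, x, key):
    noise = jax.random.normal(key, x.shape)
    # Save x as the residual so bwd can rebuild the VJP of f.
    return f(x) + noise, x


def noisy_apply_bwd(f, x, g):
    # nondiff args come first, then residuals, then the output cotangent.
    _, f_vjp = jax.vjp(f, x)
    (gx,) = f_vjp(g)
    # One entry per differentiable input (x, key); None means zero for key.
    return gx, None


noisy_apply.defvjp(noisy_apply_fwd, noisy_apply_bwd)
```

Because the noise is additive and independent of `x`, the gradient with respect to `x` should equal the gradient of `f` alone; e.g. `jax.grad(lambda x: noisy_apply(jnp.sin, x, key).sum())(x)` should match `jnp.cos(x)`.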
Seems smart! I swear I'll get to helping out with this, just a few more fires to put out first...
This was fixed by #4008!