The addition of scan() is great! I anticipate its usage will be much wider if one can call vmap() over scanned functions. It's not yet implemented.
NotImplementedError: Batching rule for 'scan' not implemented
Thank you!
+1 on this too! I can get around this when doing jacrev for the time being but not jacfwd, but jacrev has other problems :)
To add on to this, I'm trying to better understand when batching is actually needed:
import jax
import jax.numpy as jnp

def fn(a):
    return jnp.cos(a)

def loop(val):
    iterations = 10000

    def apply_carry(x, i):
        return fn(x), i

    final_val, _ = jax.lax.scan(
        apply_carry,
        val,
        jnp.arange(iterations)
    )
    return final_val

if __name__ == "__main__":
    arg = 0.5
    loop(arg)
    print(jax.grad(loop, argnums=(0,))(arg))  # works fine
    print(jax.jacrev(loop, argnums=(0,))(arg))  # error: no batching support
    print(jax.jacfwd(loop, argnums=(0,))(arg))  # error: no batching support
It looks like jacfwd and jacrev always call vmap, even for a scalar output, whereas grad behaves like a jacrev that assumes a scalar output and so bypasses vmap entirely.
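For anyone following along, here is a rough sketch of that difference as I understand it. It is not the actual JAX source: `grad_like`, `jacrev_like`, and `f` are hypothetical names made up for illustration, and only `jax.vjp` and `jax.vmap` are real API. The point is that the scalar path makes a single pullback call, while the Jacobian path maps the pullback over basis vectors with vmap, which is why every primitive inside it needs a batching rule.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical test function: elementwise, R^n -> R^n.
    return jnp.sin(x) * x

def grad_like(fun, x):
    # Scalar output: a single pullback call with cotangent 1.0, no vmap.
    y, pullback = jax.vjp(fun, x)
    (g,) = pullback(jnp.ones_like(y))
    return g

def jacrev_like(fun, x):
    # Vector output (assumed 1-D here): one pullback per output basis
    # vector, batched with vmap. This is the step that requires a
    # batching rule for every primitive used inside `fun`.
    y, pullback = jax.vjp(fun, x)
    basis = jnp.eye(y.size, dtype=y.dtype)
    (jac,) = jax.vmap(pullback)(basis)
    return jac

if __name__ == "__main__":
    x = jnp.arange(3.0)
    print(grad_like(lambda v: jnp.sum(f(v)), x))
    print(jacrev_like(f, x))
```

If this picture is right, the only thing separating the grad path from the jacrev path is the vmap call, which would explain why only the latter hits the missing batching rule.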