Jax: Please enable vmap(scan())

Created on 15 May 2019 · 3 Comments · Source: google/jax

enhancement


The addition of scan() is great! I anticipate it will be used much more widely once vmap() can be applied to scanned functions, but that isn't implemented yet:
NotImplementedError: Batching rule for 'scan' not implemented

Thank you!
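To make the request concrete, here is a minimal sketch of the vmap(scan()) composition, assuming a JAX release whose `scan` has a batching rule (the version in this issue raised the NotImplementedError above). `cumulative_sum` and `step` are hypothetical names chosen for illustration; the batched result is checked against applying the scanned function to each sequence separately.

```python
import jax
import jax.numpy as jnp


def cumulative_sum(xs):
    """Running sum of a 1-D array, written with lax.scan."""

    def step(carry, x):
        new_carry = carry + x
        return new_carry, new_carry  # (next carry, per-step output)

    _, sums = jax.lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return sums


# A batch of 3 independent sequences, each of length 4.
batch = jnp.arange(12.0).reshape(3, 4)

# vmap maps cumulative_sum over the leading (batch) axis; this is the
# vmap(scan()) composition the issue asks for.
batched = jax.vmap(cumulative_sum)(batch)

# Reference: run the scanned function on each sequence one at a time.
looped = jnp.stack([cumulative_sum(row) for row in batch])

print(jnp.allclose(batched, looped))
```

The per-row loop traces and runs the scan once per sequence, while the vmapped version runs a single scan whose carry and outputs carry an extra batch dimension.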

+1 on this too! For the time being I can work around this with jacrev, but not with jacfwd, and jacrev has other problems :)
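One possible way to sidestep a missing batching rule, sketched here as an assumption rather than the workaround the commenter actually used, is to assemble a forward-mode Jacobian column by column with `jax.jvp` in a plain Python loop, so vmap is never invoked. `jacfwd_without_vmap` and `repeated_cos` are hypothetical names; the sketch assumes a 1-D array input.

```python
import jax
import jax.numpy as jnp


def jacfwd_without_vmap(f):
    """Forward-mode Jacobian of f built from one jax.jvp call per input
    dimension in a Python loop, so no batching rule is required.
    Assumes x is a 1-D array."""

    def jac(x):
        columns = []
        for i in range(x.shape[0]):
            tangent = jnp.zeros_like(x).at[i].set(1.0)  # i-th basis vector e_i
            _, column = jax.jvp(f, (x,), (tangent,))    # J @ e_i
            columns.append(column)
        # Stack on the last axis so the result has shape out_shape + (n,),
        # matching jax.jacfwd's layout.
        return jnp.stack(columns, axis=-1)

    return jac


def repeated_cos(x):
    """Apply cos 10 times via lax.scan (no scanned-over inputs)."""

    def step(carry, _):
        return jnp.cos(carry), None

    final, _ = jax.lax.scan(step, x, None, length=10)
    return final


x = jnp.array([0.5, 1.0])
J = jacfwd_without_vmap(repeated_cos)(x)
print(J)
```

The trade-off is speed: this traces and runs `f` once per input dimension instead of once with a batched tangent, which is acceptable for small inputs but scales poorly.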

To add on to this, I'm trying to better understand when batching is actually needed:

```python
import jax
import jax.numpy as jnp

def fn(a):
    return jnp.cos(a)

def loop(val):
    iterations = 10000

    # Apply fn to the carry `iterations` times; the scanned-over index i
    # is simply passed through as the per-step output.
    def apply_carry(x, i):
        return fn(x), i

    final_val, _ = jax.lax.scan(
        apply_carry,
        val,
        jnp.arange(iterations)
    )

    return final_val

if __name__ == "__main__":
    arg = 0.5
    loop(arg)
    print(jax.grad(loop, argnums=(0,))(arg))    # works fine
    print(jax.jacrev(loop, argnums=(0,))(arg))  # error: no batching support
    print(jax.jacfwd(loop, argnums=(0,))(arg))  # error: no batching support
```

It looks like jacfwd and jacrev always call vmap, whereas grad is essentially a jacrev that assumes a scalar output and therefore bypasses vmap entirely.
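The distinction can be made concrete by rebuilding both transforms from `jax.vjp`. This is a simplified model under stated assumptions, not JAX's actual implementation of `grad` or `jacrev` (which also handle pytrees, `argnums`, `has_aux`, and so on); `grad_via_vjp`, `jacrev_via_vmap`, `sum_of_cos`, and `scaled_cos` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def grad_via_vjp(f):
    """Gradient of a scalar-output f: one VJP with cotangent 1.0.
    No vmap is involved, so no batching rule is needed."""

    def g(x):
        y, pullback = jax.vjp(f, x)
        (grad_x,) = pullback(jnp.ones_like(y))
        return grad_x

    return g


def jacrev_via_vmap(f):
    """Reverse-mode Jacobian: vmap the VJP over every standard-basis
    cotangent, producing one Jacobian row per output element."""

    def jac(x):
        y, pullback = jax.vjp(f, x)
        # One cotangent per output element, each shaped like y.
        basis = jnp.eye(y.size, dtype=y.dtype).reshape((y.size,) + y.shape)
        (rows,) = jax.vmap(pullback)(basis)  # requires batching rules
        return rows.reshape(y.shape + x.shape)

    return jac


def sum_of_cos(v):
    return jnp.sum(jnp.cos(v))


def scaled_cos(v):
    return jnp.cos(v) * v[0]


x = jnp.array([0.5, 1.0, 2.0])
print(grad_via_vjp(sum_of_cos)(x))
print(jacrev_via_vmap(scaled_cos)(x))
```

Even for a scalar output, `jacrev_via_vmap` still calls vmap over a batch of size one, which mirrors why jacrev hit the missing batching rule in cases where grad did not.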
