Currently, vmap does not seem to support functions that involve lax.while_loop. For example, the following script throws NotImplementedError: Batching rule for 'while' not implemented:
import jax.numpy as np
from jax import jit, lax, vmap
@jit
def f(x):
    return lax.while_loop(lambda x: x <= 0, lambda x: x + 1, x)
g = vmap(f)
y = g(np.arange(3.))
If supporting while_loop under vmap is complicated, could someone suggest alternatives? I know of some from the gufuncs tutorial, such as onp.vectorize or a Python for loop, but they are slow. To make g fast, can I write f in C/Cython and use vmap with it? Or should I write all of g in C/Cython? I have little experience with C/Cython but will try to learn if that is the only option.
I just want to use g as a primitive and will define a JVP rule (jvp_rule) for it separately.
Thank you for any help in advance!
Thanks for requesting this!
No need for C/Cython; we can add a batching rule for lax.while_loop. In general there are some rules we haven't gotten to implementing yet, and we tend to wait for users to ask for them as a natural way to prioritize our work.
Thank you @mattjj! That would be a nice feature. I have been playing with JAX for a few days and really love it.
Is it correct to assume that a batched while_loop would still require a cond_fun with a scalar boolean output, or do you intend to support a batched cond_fun that outputs, for each sample in the batch, a boolean saying whether that sample's loop should continue?
That's a great question! I was thinking the latter, but I haven't thought it all the way through. If we wanted the whole loop to be vectorized, we could turn the condition into an np.any check along the batch dimension (keep looping while any sample is still running) and use a mask so that only the loops still running keep updating.
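A minimal sketch of that masking scheme, written in user code rather than as an actual batching rule. `batched_while_loop` is a hypothetical helper, not a JAX API; it assumes a single array carry whose leading axis is the batch, and reduces the per-sample conditions with `jnp.any` so the loop runs until every sample has finished.

```python
import jax
import jax.numpy as jnp
from jax import lax

def batched_while_loop(cond_fun, body_fun, init_val):
    """Hypothetical helper: per-sample while loop over a batched array.

    cond_fun and body_fun act on ONE sample; init_val has the batch on axis 0.
    """
    batched_cond = jax.vmap(cond_fun)   # -> bool array of shape (batch,)
    batched_body = jax.vmap(body_fun)

    def cond(val):
        # Keep looping while at least one sample is still running.
        return jnp.any(batched_cond(val))

    def body(val):
        pred = batched_cond(val)        # which samples are still running
        new_val = batched_body(val)
        # Reshape the mask so it broadcasts against trailing per-sample dims.
        mask = pred.reshape(pred.shape + (1,) * (val.ndim - 1))
        # Running samples take the new value; finished ones stay frozen.
        return jnp.where(mask, new_val, val)

    return lax.while_loop(cond, body, init_val)

# Same loop as f in the original post, applied to np.arange(3.).
y = batched_while_loop(lambda x: x <= 0, lambda x: x + 1, jnp.arange(3.))
# y == [1., 1., 2.]
```

Because finished samples are frozen by `jnp.where`, each result matches what the unbatched loop would return; the trade-off is that the body is still computed for every sample until the slowest one finishes.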
Thoughts?
That would be great!
This is one of the reasons that I wasn't able to disentangle vectorization and masking in Matchbox; it might be reasonable to allow vmap of while (where the condition contains a vmapped dimension) to invoke the masking interpreter when that's ready.
@jekbradbury Hmm perhaps, but in this case couldn't the masking stay at the level of the while loop primitive, rather than being pushed down into the body function (like a JAX interpreter would do)?
@jonasrauber @fehiepsi Take a look at #485 and weigh in if you have comments! Will try to get that merged soon.
Thanks @mattjj ! I have a bunch of code which needs this feature. I'll check if #485 passes edge cases in my code and will get back to you soon. :)
You probably need the fix in #486 too! Both merged into master now.
@mattjj Although #486 fixes some shape errors, it does not seem to fix the TypeError raised by the following script (which generates a random number in the interval [0, 0.5] by rejection sampling):
from jax import vmap, random, lax
def f(key):
    def body_fn(uk):
        key = uk[1]
        u = random.uniform(key, ())
        key, _ = random.split(key)
        return u, key
    u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn, (1., key))
    # u = random.uniform(key, ())  # this is fine
    return u
print(f(random.PRNGKey(0))) # no error
print(vmap(f)(random.split(random.PRNGKey(0), 2))) # TypeError: 'NoneType' object cannot be interpreted as an integer
The script runs fine without vmap, or with vmap but without while_loop, so I think the error comes from combining the two. I guess there is some edge case that is not handled yet.
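Until that edge case is handled, one possible workaround (a sketch, not something proposed in this thread) is to avoid vmapping the loop at all and vectorize the rejection sampler by hand with the same masking idea: a single lax.while_loop over the whole batch that redraws only the samples still above 0.5. `sample_batch` and `draw` are hypothetical names, and the sketch assumes raw uint32 keys as returned by `random.PRNGKey`/`random.split` (shape `(n, 2)` with the default PRNG).

```python
import jax
import jax.numpy as jnp
from jax import lax, random

def sample_batch(keys):
    """Hypothetical workaround: one u in [0, 0.5] per key, no vmap of while_loop."""
    n = keys.shape[0]

    def draw(key):
        # Same per-sample step as body_fn above: draw u, then advance the key.
        u = random.uniform(key, ())
        key, _ = random.split(key)
        return u, key

    def cond(state):
        u, _ = state
        # Continue while any sample is still rejected.
        return jnp.any(u > 0.5)

    def body(state):
        u, keys = state
        new_u, new_keys = jax.vmap(draw)(keys)
        running = u > 0.5
        # Only rejected samples get a fresh draw and an advanced key.
        u = jnp.where(running, new_u, u)
        keys = jnp.where(running[:, None], new_keys, keys)
        return u, keys

    # Start every sample at 1.0 so each one draws at least once.
    u, _ = lax.while_loop(cond, body, (jnp.ones(n), keys))
    return u

u = sample_batch(random.split(random.PRNGKey(0), 2))
```

The batch dimension lives inside the loop carry from the start, so there is no unbatched `1.` in the initial state to reconcile with batched keys.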