Hi JAX Team,
I'm interested in using your library, e.g. to use gradients to optimize functions whose
terms include functions g() of Jacobians of functions f(). There's no good way to do this in PyTorch when f() has output dimension > 1, and certainly no good way to do it in a batched setting.
In JAX, I'm able to very easily get, e.g., the trace of the Jacobian of a plain relu neural net. So cool! Things break down when I try to use vmap to get a batched version of the computation. I'm wondering, do you have a rough guess at the timeline for when I might be able to vmap in this kind of setting? Note: I've successfully used JAX to vmap functions g() of Jacobians of functions f() for simple choices of f() (there's a small sketch of one such simple case after the snippet below).
```python
import jax.numpy as np
from jax import jit, jacrev, vmap

W = np.eye(3)
b = np.zeros(3)

def relu(x):
    return np.maximum(0, x)

def NN(x):
    return relu(np.dot(W, x) + b)

Jx_NN = jit(jacrev(NN))
# By the way, is there a recommended way to compose already JIT'ed functions?

def trace_Jx_NN(x):
    J = Jx_NN(x)
    return np.trace(J)

x1 = np.array([2.0, 2.0, 2.0])
x2 = np.array([3.0, 3.0, 3.0])
X = np.vstack((x1, x2))

# this works
print(trace_Jx_NN(x1))

# this line executes
batched_trace_Jx_NN = jit(vmap(trace_Jx_NN, in_axes=(0)))

# raises "NotImplementedError # TODO(schsam, mattjj): Handle more cases."
print(batched_trace_Jx_NN(X))
```
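For context, here's the kind of "simple f()" I mean, where the same trace-of-Jacobian-under-vmap pattern has worked for me (a rough sketch: a purely linear map, so no relu involved):

```python
# A linear network: same trace-of-Jacobian pattern, but without the relu.
def linear_NN(x):
    return np.dot(W, x) + b

def trace_Jx_linear(x):
    return np.trace(jacrev(linear_NN)(x))

# batching this version works fine for me
print(jit(vmap(trace_Jx_linear, in_axes=0))(X))  # prints [3. 3.] since trace(W) = 3
```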
Thanks!
mark
Dev timeline! What kind of professional and organized operation do you think we're running here? :)
More seriously, thanks for your interest, and the kind words about JAX! To answer your question, our work is driven primarily in two ways: first, we're pushing forward the core system as a research project (e.g. with new function transformations for parallel programming and automatic masking), and second, we're extending the system's coverage, adding miscellaneous features and docs, and squashing bugs. There's an approximately infinite amount of work we could do in the latter category, and so we prioritize that work based on what users ask for.
In other words, _you_ drive our timeline, along with all the other users kind enough to open GitHub issues and provide feedback.
Even better, you can join us and help accomplish these goals by contributing code. One of the strengths of JAX is that it's a pretty small, pure-Python system, so once you get over the hump of learning about the codebase, it should be relatively easy for any Python programmer to contribute. (But one of the weaknesses of JAX is that we haven't yet written down enough documentation to help new developers get on board, so that initial hump is much higher than it needs to be. We're working on that!)
On your specific question about this batching rule ("batching rules" are what we call the primitive-by-primitive transformation rules for vmap), let's consider this issue a feature request for it. @sschoenholz do you have a sense for how tricky it would be to cover the case needed by this issue? I believe this is the error being raised.
If you wanted to dig in, that line is where you'd want to start adding code. Here's a quick attempt at explaining what's going on. That function, _select_batch_rule, is basically responsible for dealing with a lax.select in the context of a vmap transformation. The game is this: the user code called lax.select (maybe indirectly through a numpy function, in this case np.maximum), blissfully ignorant that, because of a vmap, there was an extra batch dimension lurking behind the scenes. But someone has to deal with that extra dimension when it comes time to evaluate the lax.select call, and that someone is _select_batch_rule. To do that, _select_batch_rule is given information about what it has to deal with: the full arguments, with any batch dimensions exposed, and their corresponding batch dimensions, represented as an integer (indicating which dimension/axis is the one hidden from the lax.select being called) or a None if that argument doesn't have a hidden dimension after all.
It looks like we've only implemented the logic for the cases where the batch dimensions corresponding to on_true and on_false are either None (i.e. no hidden dimension to deal with) or equal to pred_dim. One easy way to handle the other cases might be just to transpose those dimensions to the front; I suspect we have to do something like that anyway.
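To make the "transpose those dimensions to the front" idea a bit more concrete, here's a rough sketch of the approach (illustrative only, and glossing over broadcasting details; not the actual _select_batch_rule implementation):

```python
import jax.numpy as np
from jax import lax

def select_batch_rule_sketch(batched_args, batch_dims):
    # batched_args: (pred, on_true, on_false), with any batch dims exposed.
    # batch_dims: for each argument, the axis index of its hidden batch
    # dimension, or None if that argument has no hidden dimension.
    size = next(x.shape[d] for x, d in zip(batched_args, batch_dims)
                if d is not None)

    def batch_dim_to_front(x, d):
        if d is None:
            # no hidden batch dimension: materialize one at the front
            return np.broadcast_to(x, (size,) + np.shape(x))
        # hidden batch dimension at axis d: move it to axis 0
        return np.moveaxis(x, d, 0)

    pred, on_true, on_false = (batch_dim_to_front(x, d)
                               for x, d in zip(batched_args, batch_dims))
    # every operand now has its batch dimension at axis 0, so a plain select
    # does the right thing; report 0 as the output's batch dimension
    return lax.select(pred, on_true, on_false), 0
```

The transpose-to-front trick is general but not always optimal, which is why we'd still want to keep cheaper lowerings for special cases like the equal-batch-dims one we already handle.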
I edited the original post to add these lines to the top of the code, making it a full repro:
```python
import jax.numpy as np
from jax import jit, jacrev, vmap

W = np.eye(3)
b = np.zeros(3)
```
Please revise further if that edit was misguided.
I've just put up a fix that generalizes the batching rule for lax.select, covering this issue, and we can probably merge it in the next hour. That's a pretty good dev timeline! (I should add another commit or two with more test cases, and maybe more efficient transpose-avoiding lowerings for special cases, like the special case we had covered before.)
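For what it's worth, the extra tests I have in mind are roughly of this shape (a sketch, not necessarily the exact tests that will land): differentiating through np.maximum is what introduces the lax.select here, and vmapping the Jacobian exercises the mismatched-batch-dimension case that used to raise NotImplementedError.

```python
import jax.numpy as np
from jax import jacrev, vmap

relu = lambda x: np.maximum(0., x)
xs = np.array([[ 2., -1., 3.],
               [-3.,  3., 3.]])

# the Jacobian of relu routes cotangents through lax.select, and only the
# input is batched, so the select sees mismatched batch dimensions
jacs = vmap(jacrev(relu))(xs)
assert jacs.shape == (2, 3, 3)
```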
Oh no! Actually, I've never worked in industry, so I didn't realize the "dev timeline" phrasing had such connotations, sorry! I just meant to ask whether it was something you had particular plans for :)
Yes, your completion of my code example was correct. I will be sure to post a fully-reproducible example next time.
Thanks so much for pursuing this case! Your message above about handling batching clarifies some things in the code for me, and I'll try to understand it a bit more.
More generally, I'm looking forward to reading the code base more in the next couple of weeks and contributing where possible. I'm doing grad school with a PL advisor and stat/ML advisor, and was thinking JAX is a good place to find some cool research problems!
That's awesome to hear. We'd love JAX to be a platform for that kind of work. I feel like we have ML+PL research problems coming out of our ears, and I'm sure there are many more to find.