Dear jax team,
I'm having trouble using vmap over multiple axes for functions that take multiple arguments, i.e. with nested tuples in in_axes (and possibly out_axes). The docstring mentions them:
in_axes: Specifies which input axes to map over. These may be integers,
None, or (possibly nested) tuples of integers or None.
but I can't get it to work.
import jax
import jax.numpy as np
import numpy as onp
solve = jax.vmap(
np.linalg.solve,
in_axes=((2, 3), (1, 2)),
out_axes=(1, 2))
A = np.array(onp.random.rand(3, 3, 100, 100), np.float32)
b = np.array(onp.random.rand(3, 100, 100), np.float32)
s12 = solve(A, b) # TypeError: <class 'tuple'>
Am I using vmap wrong? If so, is this explained somewhere?
Two overall points about vmap. First, each call to vmap adds a single batch dimension. Second, in_axes specifies, for each node in the pytree of arguments, which axis is being vmapped over. Note that if you specify an in_axes entry for a non-leaf node in the pytree, that batch axis is applied to all of its child nodes.
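For instance, here is a minimal sketch of that broadcasting behaviour (the function and names are just for illustration): a single integer given for a tuple-valued argument applies that axis to every leaf, and is shorthand for spelling it out per leaf.
import jax
import jax.numpy as np
# A function whose single argument is a pytree (a tuple of two arrays).
def f(pair):
    x, y = pair
    return x + y
xs = np.ones((4, 2))
ys = np.ones((4, 2))
# A single integer for the non-leaf tuple node applies to both leaves...
g = jax.vmap(f, in_axes=(0,))
# ...and is equivalent to specifying the axis per leaf with a nested tuple.
g_explicit = jax.vmap(f, in_axes=((0, 0),))
print(g((xs, ys)).shape)           # (4, 2)
print(g_explicit((xs, ys)).shape)  # (4, 2)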
In your case, you want to add two batch dimensions to A and one to b, so you want two nested applications of vmap, e.g.
import jax
import jax.numpy as np
import numpy as onp
solve = jax.vmap(jax.vmap(np.linalg.solve, 0, 0), (0, None), 0)
A = np.array(onp.random.rand(3, 3, 100, 100), np.float32)
b = np.array(onp.random.rand(3, 100, 100), np.float32)
s12 = solve(A, b)
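Here the outer vmap maps A over its leading axis while broadcasting b, and the inner vmap maps both arguments over their leading axes, so s12 comes out with shape (3, 3, 100, 100).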
Thank you @sschoenholz, that does the trick. But if each call to vmap adds a single batch dimension, why does the docstring mention nested tuples?
FYI, the nested vmap for my problem turns out to be
solve = jax.vmap(jax.vmap(np.linalg.solve, (2, 1), 1), (3, 2), 2)
as I'm mapping over the last axes, not the front ones.
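For completeness, a quick sanity check of that axis bookkeeping against plain NumPy (the batch indices i and j are arbitrary picks):
import jax
import jax.numpy as np
import numpy as onp
solve = jax.vmap(jax.vmap(np.linalg.solve, (2, 1), 1), (3, 2), 2)
A = np.array(onp.random.rand(3, 3, 100, 100), np.float32)
b = np.array(onp.random.rand(3, 100, 100), np.float32)
s12 = solve(A, b)  # same shape as b: (3, 100, 100)
# Solve the corresponding 3x3 system for one batch element directly.
i, j = 5, 7
expected = onp.linalg.solve(onp.asarray(A[:, :, i, j]), onp.asarray(b[:, i, j]))
print(onp.allclose(onp.asarray(s12[:, i, j]), expected, atol=1e-4))  # True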
The main use case for nested tuples is when your function takes a pytree as an argument. For example:
import jax.numpy as np
from jax import vmap
# Compute a single pre-activation neuron.
# Here params is a tuple of two values: a weight and a bias.
def one_neuron(params, x):
w, b = params
return np.dot(w, x) + b
# We can map this over a whole layer of neurons for a single input.
layer = vmap(one_neuron, (0, None), 0)
# We can alternatively map over a whole layer of neurons that all
# share a single bias.
layer_shared_bias = vmap(one_neuron, ((0, None), None), 0)
# Finally, if we want to we can map the layer over a batch of inputs.
batched_layer = vmap(layer, (None, 0), 0)
# Or with shared bias.
batched_layer_shared_bias = vmap(layer_shared_bias, (None, 0), 0)
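Continuing from those definitions, a quick shape check (the layer and batch sizes here are made up for illustration):
n_neurons, n_features, batch_size = 8, 4, 32
ws = np.ones((n_neurons, n_features))   # one weight vector per neuron
bs = np.ones((n_neurons,))              # one bias per neuron
b_shared = 0.5                          # a single bias shared by all neurons
x = np.ones((n_features,))              # one input
xs = np.ones((batch_size, n_features))  # a batch of inputs
print(layer((ws, bs), x).shape)                             # (8,)
print(layer_shared_bias((ws, b_shared), x).shape)           # (8,)
print(batched_layer((ws, bs), xs).shape)                    # (32, 8)
print(batched_layer_shared_bias((ws, b_shared), xs).shape)  # (32, 8)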
Thank you for the clarification :+1: