Jax: No vmap over multiple axes for multiple inputs

Created on 12 Aug 2019 · 4 Comments · Source: google/jax

Dear jax team,

I'm having trouble using vmap over multiple axes for functions that take multiple arguments, i.e. nested tuples in in_axes (and possibly out_axes). The docstring mentions them:

in_axes: Specifies which input axes to map over. These may be integers, None, or (possibly nested) tuples of integers or None.

but I can't get it to work.

import jax
import jax.numpy as np
import numpy as onp

solve = jax.vmap(
    np.linalg.solve, 
    in_axes=((2, 3), (1, 2)), 
    out_axes=(1, 2))
A = np.array(onp.random.rand(3, 3, 100, 100), np.float32)
b = np.array(onp.random.rand(3, 100, 100), np.float32)

s12 = solve(A, b)   # TypeError: <class 'tuple'>

Am I using vmap wrong? If so, is this explained somewhere?

All 4 comments

Two overall points about vmap.

  1. Each call to vmap adds a single batch dimension.
  2. The arguments to vmap can be any pytrees.
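A minimal sketch of point 1, assuming a recent JAX release (with `jax.numpy` imported as `jnp`, the current convention) and a hypothetical helper `dot` that works on two unbatched vectors. One `vmap` adds one batch axis, so covering every pair of rows takes two nested calls:

```python
import jax
import jax.numpy as jnp

def dot(u, v):
    # Hypothetical helper: dot product of two 1-D vectors, no batch axis of its own.
    return jnp.dot(u, v)

us = jnp.arange(12.0).reshape(3, 4)  # 3 vectors of length 4
vs = jnp.ones((5, 4))                # 5 vectors of length 4

# One vmap adds one batch axis: map over the rows of `us`, keep vs[0] fixed.
one_axis = jax.vmap(dot, in_axes=(0, None))(us, vs[0])  # shape (3,)

# Two batch axes need two vmaps: the inner one maps over `vs`,
# the outer one maps over `us`, giving every (i, j) pair.
pairwise = jax.vmap(jax.vmap(dot, in_axes=(None, 0)), in_axes=(0, None))(us, vs)  # shape (3, 5)
```

Here `pairwise[i, j]` is `dot(us[i], vs[j])`, i.e. the same thing as `us @ vs.T`.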

So in_axes specifies, for each node in the pytree, which axis is being vmapped over. Note that if you specify an in_axes value for a non-leaf node in the pytree, that batch axis is applied to all of its child nodes.
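A sketch of point 2, under the same assumptions and with a hypothetical `scale_shift` function whose first argument is a dict pytree. A single integer on the dict node applies to both leaves, while a nested spec with the same dict structure picks an axis (or `None`) per leaf. By the same rule, an entry like `(2, 3)` in `in_axes` describes a two-element pytree, not two axes of one array, which is likely why the call in the question fails: `A` is a single array (a leaf).

```python
import jax
import jax.numpy as jnp

def scale_shift(p, x):
    # `p` is a pytree (a dict with two leaves); `x` is a 1-D array.
    return p["scale"] * x + p["shift"]

x = jnp.array([1.0, 2.0])
p = {"scale": jnp.array([1.0, 2.0, 3.0]), "shift": jnp.array([0.0, 10.0, 20.0])}

# A single 0 on the dict node applies to both of its leaves.
whole_node = jax.vmap(scale_shift, in_axes=(0, None))(p, x)  # shape (3, 2)

# A nested spec picks an axis per leaf: map over "scale", share one "shift".
p_shared = {"scale": p["scale"], "shift": jnp.array(100.0)}
per_leaf = jax.vmap(
    scale_shift, in_axes=({"scale": 0, "shift": None}, None)
)(p_shared, x)  # shape (3, 2)
```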

In your case, you want to add two batch dimensions to A and one to b, so you need two nested applications of vmap, e.g.

import jax
import jax.numpy as np
import numpy as onp

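# Inner vmap maps axis 0 of both A and b; the outer vmap maps
# axis 0 of A and broadcasts b (in_axes None).
solve = jax.vmap(jax.vmap(np.linalg.solve, 0, 0), (0, None), 0)
A = np.array(onp.random.rand(3, 3, 100, 100), np.float32)
b = np.array(onp.random.rand(3, 100, 100), np.float32)

s12 = solve(A, b)  # shape (3, 3, 100, 100)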
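A quick sanity check of that nesting, assuming a recent JAX release (`jnp` alias) and deliberately small shapes. The identity shift is an illustrative choice to keep every system well conditioned. The expectation is that `s12[i, j]` equals `solve(A[i, j], b[j])`:

```python
import jax
import jax.numpy as jnp
import numpy as onp

# Same nesting as above, with small shapes so a plain loop is cheap.
solve = jax.vmap(jax.vmap(jnp.linalg.solve, 0, 0), (0, None), 0)

rng = onp.random.default_rng(0)
# Adding a multiple of the identity keeps every 4x4 system well conditioned.
A = jnp.asarray(rng.random((2, 3, 4, 4)) + 4 * onp.eye(4), dtype=jnp.float32)
b = jnp.asarray(rng.random((3, 4, 4)), dtype=jnp.float32)

s12 = solve(A, b)  # shape (2, 3, 4, 4)

# Reference: solve each system one at a time with explicit loops.
expected = jnp.stack([
    jnp.stack([jnp.linalg.solve(A[i, j], b[j]) for j in range(3)])
    for i in range(2)
])
```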

Thank you @sschoenholz, that does the trick. But if

Each call to vmap adds a single batch dimension

why does the docstring mention nested tuples?


FYI, the nested vmap for my problem turns out to be

solve = jax.vmap(jax.vmap(np.linalg.solve, (2, 1), 1), (3, 2), 2)

as I'm mapping over the last axes, not the front ones.
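A sketch checking the last-axes version, assuming a recent JAX release and small stand-in sizes `n, m` for the 100 x 100 grid (again with an illustrative identity shift for conditioning). Each 3x3 system lives in the leading axes, so `s12[:, i, j]` should solve `A[:, :, i, j] @ x = b[:, i, j]`:

```python
import jax
import jax.numpy as jnp
import numpy as onp

# Inner vmap maps axis 2 of A and axis 1 of b; the outer vmap maps
# axis 3 of A and axis 2 of b. out_axes put the batch axes back at 1 and 2.
solve = jax.vmap(jax.vmap(jnp.linalg.solve, (2, 1), 1), (3, 2), 2)

rng = onp.random.default_rng(0)
n, m = 4, 5  # small stand-ins for the 100 x 100 grid
eye = onp.eye(3)[:, :, None, None]  # broadcast the identity over the grid axes
A = jnp.asarray(rng.random((3, 3, n, m)) + 3 * eye, dtype=jnp.float32)
b = jnp.asarray(rng.random((3, n, m)), dtype=jnp.float32)

s12 = solve(A, b)  # shape (3, n, m)

# Reference: solve each 3x3 system on the grid one at a time.
expected = jnp.stack([
    jnp.stack([jnp.linalg.solve(A[:, :, i, j], b[:, i, j]) for j in range(m)], axis=-1)
    for i in range(n)
], axis=1)
```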

The main use case for nested tuples is when your function takes a pytree as an argument. For example,

import jax.numpy as np
from jax import vmap

# Compute a single pre-activation neuron.
# Here params is a tuple of two values: a weight and a bias.
def one_neuron(params, x):
  w, b = params
  return np.dot(w, x) + b

# We can map this over a whole layer of neurons for a single input.
layer = vmap(one_neuron, (0, None), 0)
# Alternatively, we can map over a whole layer of neurons that share
# a single bias.
layer_shared_bias = vmap(one_neuron, ((0, None), None), 0)

# Finally, if we want to, we can map the layer over a batch of inputs.
batched_layer = vmap(layer, (None, 0), 0)
# Or with a shared bias.
batched_layer_shared_bias = vmap(layer_shared_bias, (None, 0), 0)
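A usage sketch of the neuron example with made-up shapes (3 neurons, 2 inputs each, a batch of 4), assuming a recent JAX release; the arrays `w`, `b`, and `xs` are illustrative only. With weights stacked row-wise, both batched layers should agree with the plain matrix expression `xs @ w.T + bias`:

```python
import jax
import jax.numpy as jnp

def one_neuron(params, x):
    # params is a (weight, bias) tuple for a single neuron.
    w, b = params
    return jnp.dot(w, x) + b

layer = jax.vmap(one_neuron, (0, None), 0)                          # per-neuron w and b
layer_shared_bias = jax.vmap(one_neuron, ((0, None), None), 0)      # per-neuron w, one b
batched_layer = jax.vmap(layer, (None, 0), 0)                       # map over inputs
batched_layer_shared_bias = jax.vmap(layer_shared_bias, (None, 0), 0)

w = jnp.arange(6.0).reshape(3, 2)  # 3 neurons, 2 inputs each
b = jnp.array([1.0, 2.0, 3.0])     # one bias per neuron
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]])  # batch of 4 inputs

out = batched_layer((w, b), xs)                                   # shape (4, 3)
out_shared = batched_layer_shared_bias((w, jnp.array(0.5)), xs)   # shape (4, 3)
```

Note how `None` on the whole `(w, b)` tuple in the outer vmaps broadcasts both leaves at once, while `(0, None)` in `layer_shared_bias` treats the two leaves differently.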

Thank you for the clarification :+1:
