Likely I'm confused, but I recently tracked a bug in my program to a different output when I run (what I expect to be) the same computation with vmap vs. map.
This was run on TPUv2.
import functools
import jax
import jax.numpy as jnp
print(jax.local_device_count())
# In the code below, pts only exists to allow us to execute with pmap.
def map_version(qs, pts):
  return jax.lax.map(lambda x: func(x, pts), qs)
def vmap_version(qs, pts):
  return jax.vmap(func, in_axes=(0, None))(qs, pts)
def func(q, pts):
  """Returns q and also the view of q that the pmap'd lambda gets"""
  q_from_pmap = jax.pmap(lambda x, y: y, in_axes=(0, None))(pts, q)
  return q, q_from_pmap
device_count = 2
pts = jnp.ones(device_count)
qs = jnp.asarray(((0,0), (3,3), (2,2)))
print(f'qs.shape: {qs.shape} qs:\n{qs}')
print(f'points.shape: {pts.shape} {pts}')
print(f'map version\n-----------')
q, q_from_pmap = map_version(qs, pts)
print(f'q from func:{q.shape}\n{q}')
print(f'q_from_pmap:{q_from_pmap.shape}\n{q_from_pmap}')
print(f'vmap version\n-----------')
q, q_from_pmap = vmap_version(qs, pts)
print(f'q from func:{q.shape}\n{q}')
print(f'q_from_pmap:{q_from_pmap.shape}\n{q_from_pmap}')
Output:
8
qs.shape: (3, 2) qs:
[[0 0]
 [3 3]
 [2 2]]
points.shape: (2,) [1. 1.]
map version
-----------
q from func:(3, 2)
[[0 0]
 [3 3]
 [2 2]]
q_from_pmap:(3, 2, 2)
[[[0 0]
  [0 0]]
 [[3 3]
  [3 3]]
 [[2 2]
  [2 2]]]
vmap version
-----------
q from func:(3, 2)
[[0 0]
 [3 3]
 [2 2]]
q_from_pmap:(2, 2, 3)
[[[0 3 2]
  [0 3 2]]
 [[0 3 2]
  [0 3 2]]]
The map version does exactly what I would expect. The vmap version does not.
If I had to guess, it would be that vmap does some sort of virtual slicing, but that slicing is not maintained once qs is passed through to the pmapped lambda. The pmapped lambda seems to get all the rows of qs, when I would expect it to only get one row per batch.
Please let me know if I'm abusing vmap here, or if it's a bug.
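For reference, here is a minimal sketch of the equivalence I expected, with the pmap taken out of the picture. The helper func_no_pmap and the constant N_COPIES are hypothetical names introduced only for this illustration; jnp.broadcast_to stands in for the pmapped lambda, on the assumption that the lambda's only effect is to replicate q once per entry of pts. With that stand-in, lax.map and vmap should agree:

```python
import jax
import jax.numpy as jnp

N_COPIES = 2  # hypothetical stand-in for the size of the pmapped axis

def func_no_pmap(q, pts):
  """Returns q and q replicated once per entry of pts (no pmap involved)."""
  # Assumption: this mimics what pmap(lambda x, y: y, in_axes=(0, None))
  # is expected to return, namely one copy of q per mapped element of pts.
  q_replicated = jnp.broadcast_to(q, (pts.shape[0],) + q.shape)
  return q, q_replicated

pts = jnp.ones(N_COPIES)
qs = jnp.asarray(((0, 0), (3, 3), (2, 2)))

# Sequential map over the rows of qs.
map_q, map_rep = jax.lax.map(lambda x: func_no_pmap(x, pts), qs)
# Vectorized map over the rows of qs; pts is not batched.
vmap_q, vmap_rep = jax.vmap(func_no_pmap, in_axes=(0, None))(qs, pts)

print(map_rep.shape, vmap_rep.shape)
print(bool(jnp.array_equal(map_rep, vmap_rep)))
```

Both replicated results come out with the batch axis first, which is the layout the map version above produces and the one I expected from vmap-of-pmap as well.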
Update
If we transpose qs in vmap_version, everything works:
def vmap_version(qs, pts):
  return jax.vmap(func, in_axes=(0, None))(qs.T, pts)
This is true for my actual program as well, which is doing quite a bit more sophisticated work than this example. Doesn't feel right. :)
I think perhaps in #1959 (and the follow-up fix #2828) we neglected to handle mapped_invars correctly for vmap-of-pmap. That is, BatchTracer.process_map still assumes all arguments are mapped over by a map primitive (i.e. by pmap).
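To make the mapped vs. unmapped distinction concrete, here is a small sketch of pmap on its own with in_axes=(0, None), without any vmap around it. It is only an illustration of the public pmap behavior, not of the batching internals; the variable names are made up for this example, and the mapped axis is sized to jax.local_device_count() so it should also run on a single-device host:

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()  # size of the pmapped axis (1 on a plain CPU host)

pts = jnp.ones(n)          # mapped over: one element per device (in_axes=0)
q = jnp.asarray((3, 3))    # not mapped: every device sees all of q (in_axes=None)

# The lambda ignores the mapped argument and returns the unmapped one,
# so the result is one full copy of q per device, stacked along axis 0.
out = jax.pmap(lambda x, y: y, in_axes=(0, None))(pts, q)

print(out.shape)
```

Here only pts is split across devices, while q is passed whole to each of them. A batching rule that treats every argument as mapped would handle q's batch dimension incorrectly, which seems consistent with the transposed output in the report.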
cc @gnecula @jekbradbury
Thanks so much for catching this, and for the clear repro! #3439 should fix it.
However, I'm concerned that we don't have good, systematic test coverage here. I think vmap-of-pmap has failed on us a few times and it's fallen to users to point out the bugs. For that reason I filed #3440 to add more systematic tests of vmap-of-pmap.