Jax: Using vmap

Created on 25 Aug 2020 · 18 comments · Source: google/jax

Hi,

I want to compare the performance of using jit vs vmap, but I don't quite understand how to set the arguments for vmap.

My function expects 3 arguments (array1, array2, par), where array1 is 3-dimensional, array2 is 1-dimensional, and par is a dictionary of parameters. The operation is performed over the first dimension of array1, while array2 and par are static.

## Using jit

```python
dd = np.zeros((in1.shape[0], 3, 100), dtype=np.float32)
func_jit = jit(myFunc, static_argnums=[1, 2])
for i in range(len(in1)):
    dd[i] = func_jit(in1, in2, params)
```

## Using vmap (but this returns an error: "tuple index out of range")

```python
dd = np.zeros((in1.shape[0], 3, 100), dtype=np.float32)
func_vmap = vmap(myFunc, (0, None, None))
dd = func_vmap(in1, in2, params)
```

Thanks!

Label: question

Most helpful comment

> @jakevdp Can you give some insight into how one should choose between these two options?

The general rule is "jit at the outer-most level", i.e., outside of vmap and grad. That caches the work of tracing vmap inside jit, so it has less overhead.

All 18 comments

Hey @docyyz, thanks for the question!

(You may already know this, but for the best performance you likely want to use jit _and_ vmap together.)

Can you provide a bit more info, like the traceback you're seeing or a toy version of myFunc?

You can think of `vmap(f, (0, None, None))(xs, a, b)` as meaning this:

```python
np.stack([f(x, a, b) for x in xs])
```

Notice that by iterating `for x in xs` we're un-stacking along the first axis (i.e. axis index 0) of `xs`.
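Here's a quick runnable check of that equivalence, using a made-up toy `f` and inputs:

```python
import jax.numpy as jnp
from jax import vmap

# Hypothetical toy function; any per-example computation would do.
def f(x, a, b):
    return a * x + b

xs = jnp.arange(6.0).reshape(3, 2)  # batched along axis 0
a, b = 2.0, 1.0

batched = vmap(f, in_axes=(0, None, None))(xs, a, b)
looped = jnp.stack([f(x, a, b) for x in xs])
print(jnp.allclose(batched, looped))  # True
```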

Hi @mattjj, thanks for your response. Even after reading the docstring for vmap, I'm still confused about what I need for the in_axes and out_axes arguments in my case. The myFunc function basically does a convolution of in1[i] and in2. Sorry, there was a mistake in my sample code; the last line of the jit snippet should read:

`dd[i] = func_jit(in1[i], in2, params)`

So in the loop, I'm slicing _in1_ along the first axis and doing some mathematical operations between the sliced _in1_ and _in2_ using parameters from _params_.

Would it be possible to share your code, as a runnable repro of the issue? Then we could probably sort things out very quickly!

OK, so do I just need to create a GitHub repo and share it with you?

That would work, but maybe you could minimize it enough just to paste inline in a comment here? Since this is just an API question, we probably don't need much from your project to answer it; just enough to be concrete. You could also just make up a toy example illustrating your question.

```python
import numpy as np
import scipy as sp
import scipy.signal  # scipy does not import submodules automatically
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.scipy.signal  # ensure jsp.signal is available
from jax import jit, vmap

# Define my functions (main function is conv, which calls detrend)

def detrend(data, axis=-1, type='linear', bp=0):

    if type not in ['constant', 'linear', 'c', 'l']:
        raise ValueError("Trend type must be 'linear' or 'constant'.")
    data = jnp.asarray(data)

    if type in ['constant', 'c']:
        return data - jnp.mean(data, axis, keepdims=True)
    else:
        N = data.shape[axis]
        # bp is static, so we use np operations to avoid pushing to device.
        bp = np.sort(np.unique(np.r_[0, bp, N]))

    if np.any(bp > N):
        raise ValueError("Breakpoints must be less than length of data along given axis.")
    data = jnp.moveaxis(data, axis, 0)
    shape = data.shape
    data = data.reshape(N, -1)

    for m in range(len(bp) - 1):
        Npts = bp[m + 1] - bp[m]
        A = jnp.vstack([jnp.ones(Npts), jnp.arange(1, Npts + 1) / Npts]).T
        sl = slice(bp[m], bp[m + 1])
        coef, resids, rank, s = jnp.linalg.lstsq(A, data[sl])
        data = data.at[sl].add(-jnp.dot(A, coef))

    return jnp.moveaxis(data.reshape(shape), 0, axis)


def conv(in1,in2,par):

    nang = len(par['ang'])
    a0,b0,c0 = (par['alpha0'],par['beta0'],par['gamma0'])
    dcal = jnp.zeros((nang,in1.shape[-1]),dtype=np.float32)

    for i in range(nang):
        ang_rad = jnp.radians(par['ang'][i])
        k = b0**2 / a0**2
        d = jnp.cos(ang_rad) + jnp.sin(ang_rad)
        e = (-8.0) * k * jnp.sin(ang_rad)
        f = jnp.cos(ang_rad) - 4 * k * jnp.sin(ang_rad)
        trc = (a0*c0)*((in1[0]/a0)**d)*((in1[1]/b0)**e)*(in1[2]/c0)**f

        trc_filt = jsp.signal.convolve(detrend(trc),in2[i],mode='same')
        dcal = jax.ops.index_update(dcal,jax.ops.index[i,:],trc_filt)

    return dcal

# Set input (params, wavelet, and data)
params = {'ang':[0.0, 50.0], 'alpha0':2000.0, 'beta0':1000.0, 'gamma0':2.0, 'dt': 0.002}

wavelet = sp.signal.ricker(256,4.0)
wavelet = np.expand_dims(wavelet,axis=1).T
wavelet = np.pad(wavelet,((1,0),(0,0)),mode='edge')

alpha = np.random.normal(loc=params['alpha0'], scale=200.0, size=(500,256))
beta = np.random.normal(loc=params['beta0'], scale=100.0, size=(500,256))
gamma = np.abs(np.random.normal(loc=params['gamma0'], scale=0.5, size=(500,256)))
data = np.stack((alpha,beta,gamma),axis=1)

# Run it with Jax jit
dobs = np.zeros((data.shape[0],2,data.shape[-1]),dtype=np.float32)
conv_jit = jit(conv, static_argnums=[1,2])
for i in range(len(dobs)):
    dobs[i]=conv_jit(data[i],wavelet,params)

# Run it with Jax vmap / jit+vmap?
```

Sorry it looks a bit ugly....I just copied the code from my Notebook without doing any formatting.

I added backticks to format it.

I figured out the issue now...it's rather silly...I was slicing the first dimension of the array _data_ instead of just mapping it.

```python
dobs = np.zeros((data.shape[0], 2, data.shape[-1]), dtype=np.float32)
conv_vmap = vmap(conv, (0, None, None))
dobs = conv_vmap(data, wavelet, params)
```
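For completeness, a jit+vmap version can be composed like this (a sketch; `params` is a dict of plain Python values, so it can be passed as an ordinary pytree argument without `static_argnums`):

```python
conv_jit_vmap = jit(vmap(conv, (0, None, None)))
dobs = conv_jit_vmap(data, wavelet, params)
```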

For my example, using vmap doubles the speed compared with jit, and using jit+vmap almost halves the time again. Is there a difference between `jit(vmap(func))` and `vmap(jit(func))`? Their speeds seem to be the same.

In general `jit(vmap(func))` and `vmap(jit(func))` may produce slightly different jaxprs under the hood. For example:

```python
import jax.numpy as jnp
from jax import jit, vmap, random, make_jaxpr

key = random.PRNGKey(1701)
M = random.normal(key, (150, 100))
x = random.normal(key, (10, 100))

def apply_M(v):
  return jnp.dot(M, v)

f1 = jit(vmap(apply_M))
f2 = vmap(jit(apply_M))

%timeit f1(x).block_until_ready()
%timeit f2(x).block_until_ready()
```

```
1000 loops, best of 3: 318 µs per loop
1000 loops, best of 3: 470 µs per loop
```



md5-851d8ee84b7085bc70b04428f6540064



{ lambda b ; a.
let c = xla_call[ backend=None
call_jaxpr={ lambda ; b a.
let c = dot_general[ dimension_numbers=(((1,), (1,)), ((), ()))
precision=None ] b a
d = transpose[ permutation=(1, 0) ] c
in (d,) }
device=None
donated_invars=(False, False)
name=apply_M ] b a
in (c,) }

```python
print(make_jaxpr(f2)(x))
{ lambda b ; a.
  let c = xla_call[ backend=None
                    call_jaxpr={ lambda  ; b a.
                                 let c = dot_general[ dimension_numbers=(((1,), (1,)), ((), ()))
                                                      precision=None ] b a
                                 in (c,) }
                    device=None
                    donated_invars=(False, False)
                    name=vmap(apply_M) ] b a
      d = transpose[ permutation=(1, 0) ] c
  in (d,) }
```

You can see that the difference in the generated jaxprs here is whether the transpose happens inside or outside the mapped call. Whether this makes a significant difference probably depends on what the function is doing.

@jakevdp Can you give some insight into how one should choose between these two options?

> @jakevdp Can you give some insight into how one should choose between these two options?

The general rule is "jit at the outer-most level", i.e., outside of vmap and grad. That caches the work of tracing vmap inside jit, so it has less overhead.

+1 to what @shoyer wrote. Using jit should always produce code that is faster to execute, though the first time you run a jitted function for a set of input shapes it'll have to compile the code, and that can take time. By putting jit on the outside, you're compiling more stuff.

By the way @shoyer, vmap-of-jit also caches the vmap tracing inside jit (just like grad-of-jit caches the forward and backward passes needed), but there's still some overhead from the stuff that vmap does outside the jit. (Hope that made sense...)
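To make the "jit at the outer-most level" ordering concrete, here is a minimal sketch with a made-up toy loss:

```python
import jax.numpy as jnp
from jax import jit, vmap, grad

# Hypothetical toy loss, just to illustrate transformation ordering.
def loss(w, x):
    return jnp.sum((w * x - 1.0) ** 2)

w = jnp.ones(3)
xs = jnp.arange(12.0).reshape(4, 3)  # batch of 4 examples

# jit on the outside: the grad and vmap tracing happens once, inside
# the compiled call, instead of on every invocation.
per_example_grads = jit(vmap(grad(loss), in_axes=(None, 0)))(w, xs)
print(per_example_grads.shape)  # (4, 3)
```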

> In general `jit(vmap(func))` and `vmap(jit(func))` may produce slightly different jaxprs under the hood. For example:
> ...
> You can see that the difference in the generated jaxprs here is whether the transpose happens inside or outside the mapped call. Whether this makes a significant difference probably depends on what the function is doing.

@jakevdp This is really good to know.

@jakevdp @mattjj @shoyer @skye and team - what do you think about having the jit(vmap()) vs vmap(jit()) comparison in the Common Gotchas / 🔥 Sharp Bits 🔥 section? (https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)

@8bitmp3 well, throughout the documentation we certainly want to get across that if you jit bigger functions you're staging out more to XLA and thus you can expect better performance. But I'm not sure if it'd make sense to add anything vmap-specific; it's just an instance of the general jitting-more-stuff-is-faster advice. WDYT?

I think we covered this question, so I'm going to close the issue. Thanks for the discussion everyone!

