JAX: jit of `lax.scatter_add` does not improve performance on CPU

Created on 10 May 2019 · 12 Comments · Source: google/jax

On CPU, the following script

import numpy as onp
import jax.numpy as np
from jax import jit, lax
from jax.config import config; config.update("jax_platform_name", "cpu")

@jit
def vector_to_tril_matrix(t):
    idx = np.reshape(np.arange(79 * 79), (79, 79))[onp.tril_indices(79, 0)]
    x = lax.scatter_add(np.zeros((t.shape[0], 79 * 79)), np.expand_dims(idx, axis=-1), t,
                        lax.ScatterDimensionNumbers(update_window_dims=range(t.ndim - 1),
                                                    inserted_window_dims=(t.ndim - 1,),
                                                    scatter_dims_to_operand_dims=(t.ndim - 1,)))
    return np.reshape(x, (-1, 79, 79))

%time vector_to_tril_matrix(np.ones((8000, 3160)))
%time vector_to_tril_matrix(np.ones((8000, 3160)))

takes 5.61 s and 5.28 s, while on GPU the same script takes 631 ms and 9.29 ms.

Because the difference between the first (compiling) call and the second (cached) call is large on GPU, I expected the same on CPU, but that does not seem to be the case.
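When comparing first-call and cached-call timings like these, keep in mind that JAX dispatches work asynchronously, so a timer can stop before the device has finished unless the result is explicitly waited on. Below is a minimal sketch for separating compile time from execution time, assuming a recent JAX; `time_call` and the workload `f` are hypothetical stand-ins (substitute `vector_to_tril_matrix` and its input).

```python
import time

import jax
import jax.numpy as jnp


def time_call(fn, *args):
    """Return wall-clock seconds for one call, waiting until the result is ready.

    Without block_until_ready(), the measurement may only capture the time
    to dispatch the work, not the time to finish it.
    """
    start = time.perf_counter()
    out = fn(*args)
    out.block_until_ready()
    return time.perf_counter() - start


@jax.jit
def f(x):
    # Hypothetical stand-in workload; replace with the function under test.
    return jnp.sin(x) @ x.T


x = jnp.ones((512, 512))
first = time_call(f, x)   # includes tracing and XLA compilation
second = time_call(f, x)  # reuses the cached compiled executable
print(f"first call: {first:.4f} s, cached call: {second:.4f} s")
```

If the cached call on one backend is not much faster than the first, the compile step was not the bottleneck and the time is going into the computation itself.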

However, testing a smaller input batch on CPU shows that jit does help. For example,

%time vector_to_tril_matrix(np.ones((8, 3160)))
%time vector_to_tril_matrix(np.ones((8, 3160)))

takes 254 ms and 786 µs, as expected.

Because this batch is 1000 times smaller, I would expect vector_to_tril_matrix(np.ones((8000, 3160))) to take around 786 µs × 1000 = 786 ms (or much less if vectorization helps), but the first test shows that this is not the case.
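For reference, the scatter above places the 79 · 80 / 2 = 3160 entries of each row of `t` into the lower triangle of a 79 × 79 matrix, in the row-major order given by `onp.tril_indices(79)`. The sketch below expresses the same mapping with JAX's `.at[...].add()` indexed-update syntax instead of an explicit `lax.scatter_add` call (a swapped-in formulation, assuming a recent JAX; `vector_to_tril_matrix_at` is a hypothetical name) and checks it against a plain NumPy fill. As far as I know, `.at[...].add()` also lowers to an XLA scatter, so this is a readability alternative rather than a performance fix.

```python
import numpy as onp
import jax
import jax.numpy as jnp

N = 79  # matrix size from the issue; N * (N + 1) // 2 == 3160
rows, cols = onp.tril_indices(N)  # static NumPy indices, row-major lower triangle


@jax.jit
def vector_to_tril_matrix_at(t):
    """Fill the lower triangle of a (batch, N, N) zero array from t of shape (batch, 3160).

    Every (row, col) pair is distinct, so adding into zeros is the same as
    setting, matching the original scatter_add-into-zeros formulation.
    """
    out = jnp.zeros((t.shape[0], N, N), dtype=t.dtype)
    return out.at[:, rows, cols].add(t)


t = jnp.arange(2 * N * (N + 1) // 2, dtype=jnp.float32).reshape(2, -1)
m = vector_to_tril_matrix_at(t)

# Plain NumPy reference: assign each row of t into the lower-triangle positions.
ref = onp.zeros((2, N, N), dtype=onp.float32)
ref[:, rows, cols] = onp.asarray(t)
print(m.shape, onp.array_equal(onp.asarray(m), ref))
```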

Label: enhancement


All 12 comments

I will note that on my desktop I seem to get a substantially faster result:

CPU times: user 1.48 s, sys: 167 ms, total: 1.65 s
Wall time: 789 ms

That said, I agree those times aren't great!

My guess is that this simply means XLA's scatter implementation on CPU is slow, and the JIT compilation time is negligible compared to the scatter itself. From memory, it's currently a fairly naive reference implementation on CPU, whereas the implementation on GPU is more optimized. I'll file a feature request with the XLA folks.

Thanks, Peter! It looks like supporting various platforms takes XLA devs a lot of effort, because the implementations are done separately.

The implementations do share a lot of code, but GPU is often higher priority than CPU because that's what most deep learning users want.

In this case, I think no-one has ever gotten around to adding a CPU-specific lowering for scatter; it's actually expanded into a loop at the HLO level by this code:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/service/scatter_expander.h

The GPU code started out the same way, but eventually someone added a better implementation that gets better GPU utilization.
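To make "expanded into a loop" concrete, here is a conceptual sketch of a loop-based scatter-add along the last axis, written with `lax.fori_loop`. It is not XLA's actual `ScatterExpander` output; it only illustrates why such a lowering tends to be slow: each index is handled serially, with one dynamic slice and one dynamic update per iteration. The function name is hypothetical.

```python
import jax.numpy as jnp
from jax import lax


def naive_scatter_add_last_axis(operand, indices, updates):
    """Loop-based scatter-add: operand (B, M), indices (K,), updates (B, K).

    Iteration k reads column indices[k] of the accumulator, adds column k of
    updates to it, and writes it back, one index at a time.
    """
    def body(k, acc):
        col = lax.dynamic_slice_in_dim(updates, k, 1, axis=1)        # (B, 1)
        cur = lax.dynamic_slice_in_dim(acc, indices[k], 1, axis=1)   # (B, 1)
        return lax.dynamic_update_slice_in_dim(acc, cur + col, indices[k], axis=1)

    return lax.fori_loop(0, indices.shape[0], body, operand)


operand = jnp.zeros((3, 6))
indices = jnp.array([0, 2, 5])
updates = jnp.arange(9.0).reshape(3, 3)

out = naive_scatter_add_last_axis(operand, indices, updates)
expected = operand.at[:, indices].add(updates)  # vectorized equivalent
print(jnp.allclose(out, expected))
```

A specialized backend lowering can instead process many indices in parallel, which is the kind of improvement described above for GPU.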

> Looking like it takes much effort for XLA devs to support various platforms because the implementations are done separately.

Yes, but the promise of XLA is to solve that problem once and for all and remove it as a concern for its users.

That is, systems like TF often need to implement separate CPU and GPU versions of many ops (and those systems often have O(hundreds) or O(thousands) of ops). XLA aims to solve that problem once and for all by having a relatively small number of ops (like O(dozens)) and being able to generate code for the ops and their compositions for multiple different backends. That makes it super easy to bring up a new front-end system like JAX that can immediately target CPU or GPU (or TPU!). But the XLA team still has to solve the hard problem of generating different instructions and kernels for different platforms.

So XLA devs are solving the hard multi-device problems for us (and other libraries like JAX), and moreover their task is much better scoped because they rely on a small number of ops and compositionality.
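One way to see this small op set in practice is to look at what JAX hands to XLA. The sketch below assumes a recent JAX with the ahead-of-time `jax.jit(...).lower(...)` API; the function `f` is a hypothetical example. An indexed add shows up as a single scatter op in the emitted module, and each XLA backend decides how to implement it.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Add the two columns of x into columns 1 and 3 of a zero matrix.
    return jnp.zeros((4, 10)).at[:, jnp.array([1, 3])].add(x)


lowered = jax.jit(f).lower(jnp.ones((4, 2)))
text = lowered.as_text()  # module text emitted by JAX before backend compilation
print("scatter" in text)
```

Calling `lowered.compile().as_text()` instead prints the HLO after the backend's optimization passes, which is where a backend-specific rewrite of the scatter (such as expansion into a loop) would appear if the backend applies one.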

This is getting tangential, but on the subject of leveraging XLA to bring up new front-end systems easily, take a look at the XLA in Python notebook.

Yup, I can imagine that situation. That said, I'm lucky to be able to use JAX seamlessly across CPU/GPU, because much of the hard work is done by the JAX/XLA devs. :) Thanks a lot!

The commit mentioned above should greatly increase the speed of this computation. I observed ~20x faster on my machine.

Note that you may still see a nonlinear slowdown as the computation grows: if the tensors are 10x bigger, the computation may be 100x slower, or worse. This is a fundamental property of CPUs, caches, etc.
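A rough way to check for such a nonlinear slowdown is to measure time per row across batch sizes, warming up each shape first so compilation is excluded. This is a sketch assuming a recent JAX; `to_tril` reimplements the issue's function with `.at[...].add()` and, like `seconds_per_row`, is hypothetical. Single-run timings are noisy, so treat the numbers as indicative only.

```python
import time

import numpy as onp
import jax
import jax.numpy as jnp

N = 79
ROWS, COLS = onp.tril_indices(N)  # 3160 lower-triangle positions


@jax.jit
def to_tril(t):
    return jnp.zeros((t.shape[0], N, N), t.dtype).at[:, ROWS, COLS].add(t)


def seconds_per_row(batch):
    """Time one cached call for a given batch size and divide by the batch."""
    t = jnp.ones((batch, ROWS.size), jnp.float32)
    to_tril(t).block_until_ready()  # warm-up: compile for this shape
    start = time.perf_counter()
    to_tril(t).block_until_ready()
    return (time.perf_counter() - start) / batch


for batch in (8, 80, 800):
    # Roughly constant us/row means linear scaling; growth means superlinear cost.
    print(batch, f"{seconds_per_row(batch) * 1e6:.2f} us/row")
```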

Thank you for reporting this bug!

Woohoo XLA!

(As usual we'll need to update jaxlib, or else users will need to rebuild jaxlib from source, to get an upgraded XLA.)

For posterity, IIUC @jlebar also pointed out this parent commit as responsible for the 20x (!!) speedup he just gave us.

Thanks @jlebar! I am looking forward to trying it soon. :)

On CPU, on my MacBook, I now get:

CPU times: user 310 ms, sys: 169 ms, total: 478 ms
Wall time: 576 ms
CPU times: user 41.1 ms, sys: 4.74 ms, total: 45.8 ms
Wall time: 9.09 ms

I think we can declare this fixed! Thanks Justin!

I should add: that timing is with the newly released Jaxlib 0.1.16.
