Jax: Performance issue with sparse dot product using jit/GPU

Created on 11 Jul 2020 · 2 Comments · Source: google/jax

I am trying to compute the dot product between a sparse matrix and a vector.

The sparse matrix is represented by its row indices, column indices, and associated data (COO format). While everything performs as expected with jit on CPU, computation times become unreasonably long on GPU. Here is the gist:

https://gist.github.com/romanodev/c2c2fbfc1e788d3fd5eeeb44803b6761
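
For concreteness, here is a toy illustration of the COO layout described above (the arrays are illustrative and not taken from the gist):

import jax.numpy as jnp

# Dense matrix [[0, 2, 0],
#               [3, 0, 0],
#               [0, 0, 4]] stored in COO format:
rows = jnp.array([0, 1, 2])        # row index of each nonzero entry
cols = jnp.array([1, 0, 2])        # column index of each nonzero entry
data = jnp.array([2.0, 3.0, 4.0])  # the nonzero values
b = jnp.array([1.0, 2.0, 3.0])     # dense vector to multiply by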

This is my first approach to JAX, so I wouldn't be surprised if there was a working solution already.

Labels: performance, question

All 2 comments

Tight loops implemented using lax.fori_loop, lax.while_loop, or lax.scan tend to be very inefficient on GPU, since the best the compiler can currently do is launch one CUDA kernel per loop iteration. Because your loop performs a single scalar update per step, while CUDA kernels have roughly 5 microseconds of per-launch overhead, that overhead dominates the useful compute.
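
For reference, a minimal sketch of the kind of element-by-element loop that runs into this overhead (the original gist is not reproduced here, so the function name and update pattern are assumptions):

import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def dot_product_loop(rows, cols, data, b):
    # One scalar update per iteration: on GPU each step costs at least one
    # kernel launch, so launch overhead swamps the actual arithmetic.
    def body(i, out):
        return out.at[rows[i]].add(data[i] * b[cols[i]])
    return lax.fori_loop(0, rows.shape[0], body, jnp.zeros_like(b))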

Luckily, in your case, the loop can instead be implemented with NumPy advanced indexing (for which JAX provides a slightly nonstandard functional syntax), resulting in just one or two compiled CUDA kernels (a gather plus a scatter) and a much shorter runtime:

import jax
import jax.numpy as jnp

@jax.jit
def dot_product2(rows, cols, data, b):
    # Gather the vector entries for each stored element and scale by the values.
    updates = data * b[cols]
    # Scatter-add the contributions into the corresponding output rows.
    return jnp.zeros_like(b).at[rows].add(updates)

>>> %timeit dot_product2(rows, cols, data, b).block_until_ready()
10000 loops, best of 3: 130 µs per loop
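
Two details worth noting about the indexed-update version: .at[rows].add accumulates contributions when the same row index appears more than once, so rows with several nonzero entries are summed correctly, and block_until_ready() is needed for meaningful timings because JAX dispatches GPU work asynchronously.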

@jekbradbury, fantastic. Thanks!

As a reference, this is the updated version:

https://gist.github.com/romanodev/34a471e8914989dbedee85febdbe7c77
