Hi JAX,
I'm interested in performing dot products between two sparse matrices. Currently it looks like no accelerators support this - JAX and TensorFlow support sparse-matrix-by-dense-vector products, but that's as close as I can find.
I've found how scipy does it ( https://github.com/scipy/scipy/blob/d4789babc26abcf09130b411df10eb72f8401d54/scipy/sparse/sparsetools/bsr.h#L249 ) but I'm not sure I can figure out how that would look in JAX.
However, I am about to apply for an internal grant offered at my institution that would get me some dedicated time with experienced software engineers, who I'm sure can help push this forward. I thought I would ask here first if anyone has ideas or, if it's a bad idea, to dissuade me from it so I don't waste the grant time!
Cheers
Lewis
edit: thought I would add some context. Calculating pairwise jaccard distances between two sets of sparse vectors requires calculating the (dense) pairwise matrix of intersections. One way to do this is:
from scipy import sparse
# my_data is a binary (samples x features) array
mat = sparse.csr_matrix(my_data)
intersection = mat.dot(mat.T)
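For concreteness, here's a self-contained sketch of the full pairwise Jaccard computation this enables (the toy `my_data` array is made up; it assumes binary data, where row sums give the set sizes):

```python
import numpy as np
from scipy import sparse

# Toy binary data: 4 samples, 6 features
my_data = np.array([
    [1, 1, 0, 0, 1, 0],
    [1, 0, 1, 0, 1, 0],
    [0, 1, 0, 1, 0, 1],
    [1, 1, 1, 0, 0, 0],
])

mat = sparse.csr_matrix(my_data)

# Pairwise intersection sizes via a sparse-sparse matmul (dense result)
intersection = mat.dot(mat.T).toarray()

# |A ∪ B| = |A| + |B| - |A ∩ B|
sizes = my_data.sum(axis=1)
union = sizes[:, None] + sizes[None, :] - intersection

jaccard_dist = 1 - intersection / union
```

The sparse-sparse product `mat.dot(mat.T)` is exactly the step with no accelerated JAX equivalent yet.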
Sorry, I linked the BSR implementation above but should have referred to the CSR one, i.e. https://github.com/scipy/scipy/blob/2eeed791ce100fd2bf127967ec41759ad3b4598c/scipy/sparse/sparsetools/csr.h#L562
I see there's already some work on sparse operations: https://github.com/google/jax/pull/4422 , so perhaps I'll just wait for that to progress to CSR matmuls?
As far as I understand, there's no way to vectorize a CSR matmul with plain indexing, because the number of entries per row varies dynamically; the alternative is iterating over all index pairs and checking for equality, which would involve a huge amount of 'if/else' branching and may not perform well on GPUs.
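To illustrate why gather/scatter sidesteps the per-row branching problem, here's a plain-NumPy sketch of a sparse-dense matvec in COO format (this is only an illustration of the idea, not how #4422 actually implements things; the small matrix is made up). Every array has a fixed shape - nnz - so there are no data-dependent loops or conditionals:

```python
import numpy as np

# COO representation of a 3x4 sparse matrix A
rows = np.array([0, 0, 1, 2, 2])
cols = np.array([0, 3, 1, 0, 2])
vals = np.array([1.0, 2.0, 3.0, 4.0, 5.0])

x = np.array([1.0, 2.0, 3.0, 4.0])

# Gather: pull the entries of x matching each stored column index.
# Result has fixed shape (nnz,), independent of the per-row structure.
products = vals * x[cols]

# Scatter-add: accumulate each product into its output row.
# np.add.at plays the role of XLA's scatter / a segment-sum here.
y = np.zeros(3)
np.add.at(y, rows, products)
```

Both steps map onto XLA-supported primitives, which is why this style of approach can be performant without explicit if/else.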
Cheers :)
Hi @ljmartin – I think #4422 is the thing to watch. You're right that currently there is no XLA support for sparse operations, but that PR uses scatter/gather approaches that are supported by XLA and don't require any explicit loops or conditionals, so they should be fairly performant.
Once I iron out some of the issues with that approach, my hope is to use JAX's translation rules to target e.g. cuSparse on GPU, and to think about how to provide similar fast operations on other backends as well.