Jax: Einsum is slow

Created on 4 Feb 2020 · 5 Comments · Source: google/jax

An example where np.einsum is consistently slower than the equivalent manual matmul/transpositions, on both CPU and GPU. (The example in #1966 runs equally fast for me, but this one is consistently slower.)

https://colab.research.google.com/gist/romanngg/e63834765d00497e315455867a52eae1/einsum_is_slow.ipynb

import jax.numpy as np
import jax.random as random
from jax import jit  # older JAX releases also exposed this as jax.api.jit

a = random.normal(random.PRNGKey(1), (100, 20, 20, 3))
b = random.normal(random.PRNGKey(2), (200, 20, 20, 3))

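@jit
def matmul(a, b):
  # a: (n, x, y, c) -> (x, y, n, c); b: (m, x, y, c) -> (x, y, c, m).
  # matmul then batches over (x, y) and contracts c, giving (x, y, n, m).
  out = np.matmul(np.transpose(a, axes=(1, 2, 0, 3)),
                  np.transpose(b, axes=(1, 2, 3, 0)))
  # Reorder (x, y, n, m) -> (n, m, x, y) to match the einsum output.
  return np.transpose(out, axes=(2, 3, 0, 1))

@jit
def einsum(a, b):
  # The same contraction over c, batched over (x, y), as a single einsum.
  return np.einsum('nxyc,mxyc->nmxy', a, b, optimize=True)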

np.sum(np.abs(einsum(a, b) - matmul(a, b)))

%timeit einsum(a, b).block_until_ready()

%timeit matmul(a, b).block_until_ready()
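%timeit only works inside IPython/Colab. A minimal sketch of the same comparison as a plain script, assuming a current JAX release and the standard-library timeit module; the helper name bench is hypothetical. Calling each jitted function once before timing keeps compilation out of the measurement, and block_until_ready() waits for JAX's asynchronous dispatch to finish.

```python
import timeit

import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(1), (100, 20, 20, 3))
b = jax.random.normal(jax.random.PRNGKey(2), (200, 20, 20, 3))

@jax.jit
def matmul(a, b):
  # (n,x,y,c) -> (x,y,n,c) and (m,x,y,c) -> (x,y,c,m); batch over (x, y).
  out = jnp.matmul(jnp.transpose(a, (1, 2, 0, 3)), jnp.transpose(b, (1, 2, 3, 0)))
  return jnp.transpose(out, (2, 3, 0, 1))  # (x,y,n,m) -> (n,m,x,y)

@jax.jit
def einsum(a, b):
  return jnp.einsum('nxyc,mxyc->nmxy', a, b, optimize=True)

def bench(fn, *args, number=20):
  """Hypothetical helper: mean seconds per call, excluding compilation."""
  fn(*args).block_until_ready()  # warm-up call compiles outside the timed region
  total = timeit.timeit(lambda: fn(*args).block_until_ready(), number=number)
  return total / number

t_einsum = bench(einsum, a, b)
t_matmul = bench(matmul, a, b)
print(f"einsum: {t_einsum * 1e3:.3f} ms, matmul: {t_matmul * 1e3:.3f} ms")
```

Absolute timings depend on the backend and JAX/XLA version, so only the ratio between the two is meaningful.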

Also note that on CPU the difference between the two outputs is non-zero (DeviceArray(0.01003271, dtype=float32)); I'm not sure how concerning that is.
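A raw sum of absolute differences over 8 million output elements is hard to interpret on its own. A minimal sketch of a scale-aware check, assuming a current JAX release; the tolerances passed to allclose are illustrative choices for float32, not values taken from the issue.

```python
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(1), (100, 20, 20, 3))
b = jax.random.normal(jax.random.PRNGKey(2), (200, 20, 20, 3))

via_einsum = jnp.einsum('nxyc,mxyc->nmxy', a, b)
via_matmul = jnp.transpose(
    jnp.matmul(jnp.transpose(a, (1, 2, 0, 3)), jnp.transpose(b, (1, 2, 3, 0))),
    (2, 3, 0, 1))

diff = jnp.abs(via_einsum - via_matmul)
total_abs_diff = float(jnp.sum(diff))  # the metric used in the issue
max_abs_diff = float(jnp.max(diff))
# Largest error relative to the largest output magnitude.
rel_err = max_abs_diff / float(jnp.max(jnp.abs(via_matmul)))
eps = float(jnp.finfo(jnp.float32).eps)  # float32 machine epsilon

print(f"sum |diff| = {total_abs_diff:.3g}, max |diff| = {max_abs_diff:.3g}, "
      f"relative = {rel_err:.3g} (float32 eps = {eps:.3g})")
print("allclose:", bool(jnp.allclose(via_einsum, via_matmul, rtol=1e-5, atol=1e-5)))
```

A relative error within a small multiple of float32 epsilon points to rounding differences rather than an incorrect result.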

All 5 comments

FYI, I have revisited the example above on each backend:

1) CPU: einsum is slow AND wrong: https://colab.research.google.com/gist/romanngg/48fb8d4d3a3fb5da9be84d8d1fb862ad/einsum_is_wrong_and_slow_cpu.ipynb
2) GPU: einsum is slow: https://colab.research.google.com/gist/romanngg/dd1e2adbda90749f140012f1b9342353/einsum_is_slow_gpu.ipynb
3) TPU: einsum is OK! https://colab.research.google.com/gist/romanngg/635b467426bd9ead276cc6f9216ed03d/einsum_is_ok_tpu.ipynb

Will file bugs against XLA:CPU and XLA:GPU!

@romanngg Curious about progress.
Also, the difference on CPU is quite small (0.01 after summing over all elements). That's floating-point imprecision, not an error.

Haven't heard anything back yet.

I think "wrong" is an overstatement here. In floating-point arithmetic, two different ways of computing the same result are not guaranteed to agree exactly. That is particularly true for heavily optimized routines such as matrix multiplication.
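To make the point concrete: floating-point addition is not associative, so an implementation that accumulates a reduction in a different order (as optimized dot/matmul kernels may do for blocking or vectorization) can legitimately change the low-order bits of the result. A minimal illustration in float32 using NumPy scalars:

```python
import numpy as np

x = np.float32(1e8)
y = np.float32(-1e8)
z = np.float32(1.0)

# Same three numbers, two association orders.
left = (x + y) + z   # 0 + 1 gives 1.0
right = x + (y + z)  # -1e8 + 1 rounds back to -1e8 in float32, so this is 0.0
print(left, right)
```

Neither order is "wrong"; they are two valid roundings of the same mathematical sum.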

Note, for example, that NumPy's einsum gives an even larger difference here (in this snippet np is NumPy and jnp is jax.numpy):

>>> np.sum(np.abs(np.einsum('nxyc,mxyc->nmxy', a, b) - jnp.einsum('nxyc,mxyc->nmxy', a, b)))
0.3189874

If you look at the jaxprs of these two functions (matmul vs. einsum), you can see that although they compute the same thing in principle, they compute it differently:

>>> jax.make_jaxpr(einsum)(a, b)
{ lambda  ; a b.
  let c = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a b.
                                 let c = xla_call[ backend=None
                                                   call_jaxpr={ lambda  ; a b.
                                                                let c = dot_general[ dimension_numbers=(((3,), (3,)), ((1, 2), (1, 2)))
                                                                                     precision=None ] b a
                                                                    d = transpose[ permutation=(3, 2, 0, 1) ] c
                                                                in (d,) }
                                                   device=None
                                                   donated_invars=(False, False)
                                                   name=_einsum ] a b
                                 in (c,) }
                    device=None
                    donated_invars=(False, False)
                    name=einsum ] a b
  in (c,) }

>>> jax.make_jaxpr(matmul)(a, b)
{ lambda  ; a b.
  let c = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a b.
                                 let c = transpose[ permutation=(1, 2, 0, 3) ] a
                                     d = transpose[ permutation=(1, 2, 3, 0) ] b
                                     e = dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))
                                                      precision=None ] c d
                                     f = transpose[ permutation=(2, 3, 0, 1) ] e
                                 in (f,) }
                    device=None
                    donated_invars=(False, False)
                    name=matmul ] a b
  in (c,) }
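The two jaxprs boil down to the same contraction expressed with different dimension numbers. A minimal sketch that replays both with jax.lax.dot_general directly, assuming a recent JAX release; the dimension numbers and permutations are copied from the jaxprs above, and the exact jaxpr einsum emits may differ across JAX versions.

```python
import jax
import jax.numpy as jnp
from jax import lax

a = jax.random.normal(jax.random.PRNGKey(1), (100, 20, 20, 3))  # (n, x, y, c)
b = jax.random.normal(jax.random.PRNGKey(2), (200, 20, 20, 3))  # (m, x, y, c)

# einsum path: contract c on the original layouts, batching over (x, y).
# dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)).
c1 = lax.dot_general(b, a, dimension_numbers=(((3,), (3,)), ((1, 2), (1, 2))))
# Output order: batch dims, then lhs free dims, then rhs free dims -> (x, y, m, n).
via_einsum_path = jnp.transpose(c1, (3, 2, 0, 1))  # -> (n, m, x, y)

# matmul path: transpose first so the batch dims lead, then a batched matmul.
at = jnp.transpose(a, (1, 2, 0, 3))  # (x, y, n, c)
bt = jnp.transpose(b, (1, 2, 3, 0))  # (x, y, c, m)
c2 = lax.dot_general(at, bt, dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1))))
via_matmul_path = jnp.transpose(c2, (2, 3, 0, 1))  # (x, y, n, m) -> (n, m, x, y)

reference = jnp.einsum('nxyc,mxyc->nmxy', a, b)
print(c1.shape, c2.shape)
print(bool(jnp.allclose(via_einsum_path, reference, atol=1e-5)),
      bool(jnp.allclose(via_matmul_path, reference, atol=1e-5)))
```

Both produce the same (n, m, x, y) result; they differ only in operand layout and in which axes XLA sees as batch, contracting, and free dimensions.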

XLA is usually pretty good about picking a good way to implement matrix multiplication, but it's not making the best choice here on CPU/GPU without your manual transposes. Those are definitely good opportunities for further improvement.

To be clear, the issue here isn't that einsum itself is slow: as you can see, it generates quite reasonable code (a single dot_general followed by a transpose). The slowdown is an indictment of XLA's DotGeneral implementation on these backends (which, again, usually does pretty well).
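To check what XLA is actually given, and what it does with the DotGeneral on a particular backend, recent JAX releases can dump both the StableHLO handed to XLA and the optimized HLO after compilation. A minimal sketch, assuming a recent JAX version (the jit(...).lower(...) API postdates this issue); the exact text of the optimized HLO varies by backend and XLA version, so the line filters below are only a rough heuristic.

```python
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(1), (100, 20, 20, 3))
b = jax.random.normal(jax.random.PRNGKey(2), (200, 20, 20, 3))

def einsum(a, b):
  return jnp.einsum('nxyc,mxyc->nmxy', a, b, optimize=True)

lowered = jax.jit(einsum).lower(a, b)
stablehlo_text = lowered.as_text()           # what JAX hands to XLA
optimized_hlo = lowered.compile().as_text()  # what XLA compiled for this backend

# Print only lines mentioning dots/transposes to keep the output short.
for line in stablehlo_text.splitlines():
  if "dot_general" in line or "transpose" in line:
    print(line.strip())
print("----")
for line in (optimized_hlo or "").splitlines():
  if " dot(" in line or "transpose" in line or "custom-call" in line:
    print(line.strip())
```

Running the same script on CPU, GPU, and TPU shows whether XLA inserts extra transposes or layout changes around the dot for a given backend.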
