An example where np.einsum is slower than manual matmul/transpositions on CPU and GPU. (The example from #1966 runs equally fast for me, but this one is consistently slower.)
import jax.numpy as np
import jax.random as random
from jax.api import jit

a = random.normal(random.PRNGKey(1), (100, 20, 20, 3))
b = random.normal(random.PRNGKey(2), (200, 20, 20, 3))

@jit
def matmul(a, b):
    return np.transpose(
        np.matmul(np.transpose(a, axes=(1, 2, 0, 3)),
                  np.transpose(b, axes=(1, 2, 3, 0))),
        axes=(2, 3, 0, 1))

@jit
def einsum(a, b):
    return np.einsum('nxyc,mxyc->nmxy', a, b, optimize=True)

np.sum(np.abs(einsum(a, b) - matmul(a, b)))

%timeit einsum(a, b).block_until_ready()
%timeit matmul(a, b).block_until_ready()
Also note that if you run it on CPU, the difference between the two methods' outputs becomes non-zero: DeviceArray(0.01003271, dtype=float32). Not sure how concerning that is.
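To put that 0.01 in context, here is a rough sanity check (continuing from the snippet above; exact numbers will vary by machine): the output has 100 * 200 * 20 * 20 = 8,000,000 elements, so the per-element discrepancy is on the order of 1e-9, well within float32 rounding for values of this magnitude.

diff = np.abs(einsum(a, b) - matmul(a, b))
print(diff.mean())  # roughly 1e-9 per element, if the 0.01 total above is typical
# If the mismatch is just rounding, an element-wise comparison with float32-scale
# tolerances should pass:
print(np.allclose(einsum(a, b), matmul(a, b), atol=1e-5))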
FYI, I have revisited the example below on:
1) CPU: einsum is slow AND wrong: https://colab.research.google.com/gist/romanngg/48fb8d4d3a3fb5da9be84d8d1fb862ad/einsum_is_wrong_and_slow_cpu.ipynb
2) GPU: einsum is slow: https://colab.research.google.com/gist/romanngg/dd1e2adbda90749f140012f1b9342353/einsum_is_slow_gpu.ipynb
3) TPU: einsum is OK! https://colab.research.google.com/gist/romanngg/635b467426bd9ead276cc6f9216ed03d/einsum_is_ok_tpu.ipynb
Will file bugs against XLA:CPU and XLA:GPU!
@romanngg
Curious about progress.
Also, the difference on CPU is quite small (0.01 after summing over all elements). That's imprecision, not an error.
Haven't heard anything back yet.
I think "wrong" is an overstatement here. In floating point arithmetic, two different ways of computing the same results are not guaranteed to exactly agree. That is particular true for heavily optimized routines such as matrix multiplication.
Note, for example, that NumPy's einsum gives an even more different result here (np below is vanilla NumPy, jnp is jax.numpy):
>>> np.sum(np.abs(np.einsum('nxyc,mxyc->nmxy', a, b) - jnp.einsum('nxyc,mxyc->nmxy', a, b)))
0.3189874
If you look at the implementations of these two functions (matmul vs einsum), even though they compute the same thing in principle, they compute it differently:
>>> jax.make_jaxpr(einsum)(a, b)
{ lambda ; a b.
let c = xla_call[ backend=None
call_jaxpr={ lambda ; a b.
let c = xla_call[ backend=None
call_jaxpr={ lambda ; a b.
let c = dot_general[ dimension_numbers=(((3,), (3,)), ((1, 2), (1, 2)))
precision=None ] b a
d = transpose[ permutation=(3, 2, 0, 1) ] c
in (d,) }
device=None
donated_invars=(False, False)
name=_einsum ] a b
in (c,) }
device=None
donated_invars=(False, False)
name=einsum ] a b
in (c,) }
>>> jax.make_jaxpr(matmul)(a, b)
{ lambda ; a b.
let c = xla_call[ backend=None
call_jaxpr={ lambda ; a b.
let c = transpose[ permutation=(1, 2, 0, 3) ] a
d = transpose[ permutation=(1, 2, 3, 0) ] b
e = dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))
precision=None ] c d
f = transpose[ permutation=(2, 3, 0, 1) ] e
in (f,) }
device=None
donated_invars=(False, False)
name=matmul ] a b
in (c,) }
XLA is usually pretty good at picking an efficient way to implement matrix multiplication, but it's not making the best choice here on CPU/GPU without your manual transposes. Those are definitely good opportunities for further improvement.
To be clear, the issue here isn't einsum itself, which as you can see generates quite reasonable code. This is an indictment of XLA's DotGeneral (which, again, usually does pretty well).
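If you want to poke at the DotGeneral call in isolation (to benchmark it directly, or to try a different precision), here is a rough sketch that reproduces the einsum contraction with lax.dot_general, reusing the dimension numbers from the jaxpr above and the np/jit/arrays from the first snippet. It's a reconstruction for experimentation, not necessarily byte-for-byte what jnp.einsum emits:

from jax import lax

@jit
def dot_general_version(a, b):
    # Same contraction as in the einsum jaxpr above: contract over the channel
    # axis (3), batch over the spatial axes (1, 2), with operands ordered (b, a),
    # so the raw output is laid out as (x, y, m, n).
    out = lax.dot_general(b, a, dimension_numbers=(((3,), (3,)), ((1, 2), (1, 2))))
    # Bring it back to (n, m, x, y), matching the (3, 2, 0, 1) transpose in the jaxpr.
    return np.transpose(out, axes=(3, 2, 0, 1))

# Should agree with einsum(a, b) up to float32 rounding:
# np.allclose(dot_general_version(a, b), einsum(a, b), atol=1e-5)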