Hey, thanks for the great work here!
I noticed that matmuls for complex dtypes are ~20-25x slower on my MacBook than they are for real dtypes. Here is a simple script that does the timing.
Wondering if this is expected, or what I can do to speed that up. Thanks!
```python
import numpy as np
import jax
from jax import config
config.update('jax_enable_x64', True)
import time

@jax.jit
def do_matvec_simple(matrix, vector):
    res = 0
    for _ in range(100):
        res += matrix @ vector
    return res

@jax.jit
def do_matmul_simple(matrix1, matrix2):
    res = 0
    for _ in range(100):
        res += matrix1 @ matrix2
    return res

def run_timings_dot(dtype, D):
    matrix = jax.numpy.array(np.random.rand(D, D).astype(dtype))
    vector = jax.numpy.array(np.random.rand(D).astype(dtype))
    t1 = time.time()
    for _ in range(100):
        res = matrix @ vector
        res.block_until_ready()
    print(f'loop over 100 matrix-vector muls in dtype {np.dtype(dtype).name}', time.time() - t1)
    res = do_matvec_simple(matrix, vector)  # warm-up call to trigger compilation
    res.block_until_ready()
    t1 = time.time()
    res = do_matvec_simple(matrix, vector)
    res.block_until_ready()
    print(f'jit 100 do_matvec_simple for dtype {np.dtype(dtype).name}', time.time() - t1)

def run_timings_matmul_simple(dtype, D):
    A = jax.numpy.array(np.random.rand(D, D).astype(dtype))
    B = jax.numpy.array(np.random.rand(D, D).astype(dtype))
    t1 = time.time()
    for _ in range(100):
        res = A @ B
        res.block_until_ready()
    print(f'loop over 100 matrix-matrix muls in dtype {np.dtype(dtype).name}', time.time() - t1)
    res = do_matmul_simple(A, B)  # warm-up call to trigger compilation
    res.block_until_ready()
    t1 = time.time()
    res = do_matmul_simple(A, B)
    res.block_until_ready()
    print(f'jit 100 do_matmul_simple for dtype {np.dtype(dtype).name}', time.time() - t1)

print('######## timings for matrix-vector ###########')
print(' ---------- float64 --------------')
run_timings_dot(np.float64, 1000)
print(' ---------- complex128 --------------')
run_timings_dot(np.complex128, 1000)
print()
print()
print('######## timings for matrix-matrix ###########')
print(' ---------- float64 --------------')
run_timings_matmul_simple(np.float64, 400)
print(' ---------- complex128 --------------')
run_timings_matmul_simple(np.complex128, 400)
```
Update: disabling double precision seems to increase the slowdown to ~100x.
What hardware platform are you using?
I'm running it on a 2018 MacBook Pro, quad-core i7 with 16 GB DDR3 RAM.
What timings from your script do you see on your machine?
When I run it in Google Colab on a CPU kernel, I see:
```
######## timings for matrix-vector ###########
 ---------- float64 --------------
loop over 100 matrix-vector muls in dtype float64 0.34077000617980957
jit 100 do_matvec_simple for dtype float64 0.0012781620025634766
 ---------- complex128 --------------
loop over 100 matrix-vector muls in dtype complex128 0.16057896614074707
jit 100 do_matvec_simple for dtype complex128 0.002154827117919922

######## timings for matrix-matrix ###########
 ---------- float64 --------------
loop over 100 matrix-matrix muls in dtype float64 1.5076098442077637
jit 100 do_matmul_simple for dtype float64 0.015089273452758789
 ---------- complex128 --------------
loop over 100 matrix-matrix muls in dtype complex128 10.45351505279541
jit 100 do_matmul_simple for dtype complex128 0.10727953910827637
```
In summary (looking at the jit numbers): complex128 is ~1.7x slower than float64 for matrix-vector and ~7x slower for matrix-matrix.
This is roughly within expectations for me. Naively, multiplication of complex numbers involves 4x more multiplications ((a+bi)(c+di) = (ac-bd) + (bc+ad)i), so any improvement over that for matrix/vector math is a win.
Modern CPUs have special instructions for dense matrix/matrix multiplication (likely specialized for particular dtypes such as float32), so I'm not surprised that the gap there is larger.
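To make that multiplication count concrete, here is a minimal sketch (the helper name is made up for illustration, and this is not how XLA actually implements complex matmul) that assembles a complex matmul from four real matmuls:

```python
import numpy as np
import jax.numpy as jnp

def complex_matmul_via_real(a, b):
    # (Ar + i*Ai) @ (Br + i*Bi) = (Ar@Br - Ai@Bi) + i*(Ar@Bi + Ai@Br):
    # four real matmuls plus a couple of elementwise additions.
    ar, ai = jnp.real(a), jnp.imag(a)
    br, bi = jnp.real(b), jnp.imag(b)
    return (ar @ br - ai @ bi) + 1j * (ar @ bi + ai @ br)

a = jnp.asarray(np.random.rand(8, 8) + 1j * np.random.rand(8, 8))
b = jnp.asarray(np.random.rand(8, 8) + 1j * np.random.rand(8, 8))
print(jnp.allclose(complex_matmul_via_real(a, b), a @ b))  # True
```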
Thanks for replying so quickly!
Yes, the FLOP count ratio is (6*M + 2*M)*M**2 / ((M + M)*M**2) = 4, which is the baseline of my expectation, and is also what you get from MKL kernels.
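Spelling that count out as a quick check:

```python
# Per output element of an M x M matmul:
#   real:    M multiplies + M adds           -> (M + M) flops
#   complex: 6 flops per multiply + 2 per add -> (6*M + 2*M) flops
M = 1000
print((6 * M + 2 * M) * M**2 / ((M + M) * M**2))  # 4.0
```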
I attached the runtimes on my MacBook and on my Linux machine below. I changed the timing of the naive for loop to report 100x the median runtime of 100 matmuls, to make it more robust against outliers (see the changed code below).
On both machines, floatX runtimes are substantially smaller than complexX runtimes, with ratios well beyond 4x (or even 7x), though the slowdown is less pronounced on my Linux machine.
The slowdown is also consistent between the jitted version and the naive for loop.
@shoyer it looks like Colab runtimes for complex128 are roughly consistent with mine, while they are substantially slower for float64.
The results on Colab for single precision (see below) are again very similar to the ones I get on e.g. my Linux machine. Not sure what's going on, but it's a bit puzzling.
In summary:
__slowdown of complex vs. float matmuls (simple for loop, no jit)__

| | MacBook | ThinkPad | Colab |
|------------------|---------|----------|-------|
| single precision | ~200x | ~100x | ~50x |
| double precision | ~30x | ~10x | ~7x |
Here are the timings from my MacBook Pro:
```
######## timings for matrix-matrix ###########
 ---------- float --------------
loop over 100 matrix-matrix muls in dtype float32 0.055694580078125
do_matmul_simple (100 matmuls, jitted) for dtype float32 0.0010039806365966797
 ---------- complex --------------
loop over 100 matrix-matrix muls in dtype complex64 9.781551361083984
do_matmul_simple (100 matmuls, jitted) for dtype complex64 0.10428619384765625

######## timings for matrix-matrix ###########
 ---------- float --------------
loop over 100 matrix-matrix muls in dtype float64 0.25349855422973633
do_matmul_simple (100 matmuls, jitted) for dtype float64 0.0029239654541015625
 ---------- complex --------------
loop over 100 matrix-matrix muls in dtype complex128 6.962347030639648
do_matmul_simple (100 matmuls, jitted) for dtype complex128 0.08057284355163574
```
Here are the timings on my Linux machine (Lenovo ThinkPad running Ubuntu 18.04):
```
######## timings for matrix-matrix ###########
 ---------- float --------------
loop over 100 matrix-matrix muls in dtype float32 0.11262893676757812
do_matmul_simple (100 matmuls, jitted) for dtype float32 0.0013873577117919922
 ---------- complex --------------
loop over 100 matrix-matrix muls in dtype complex64 8.658218383789062
do_matmul_simple (100 matmuls, jitted) for dtype complex64 0.09708619117736816

######## timings for matrix-matrix ###########
 ---------- float --------------
loop over 100 matrix-matrix muls in dtype float64 0.6018519401550293
do_matmul_simple (100 matmuls, jitted) for dtype float64 0.00616002082824707
 ---------- complex --------------
loop over 100 matrix-matrix muls in dtype complex128 7.223522663116455
do_matmul_simple (100 matmuls, jitted) for dtype complex128 0.07247376441955566
```
On Colab I get the following for single precision:
```
######## timings for matrix-matrix ###########
 ---------- float --------------
loop over 100 matrix-matrix muls in dtype float32 0.23915767669677734
do_matmul_simple (100 matmuls, jitted) for dtype float32 0.002748727798461914
 ---------- complex --------------
loop over 100 matrix-matrix muls in dtype complex64 9.820866584777832
do_matmul_simple (100 matmuls, jitted) for dtype complex64 0.10582566261291504
```
Here's the updated timing code:
```python
import numpy as np
import jax
from jax import config
config.update('jax_enable_x64', False)
import time

@jax.jit
def do_matmul_simple(matrix1, matrix2):
    res = 0
    for _ in range(100):
        res += matrix1 @ matrix2
    return res

def run_timings_matmul_simple(dtype, D):
    A = jax.numpy.array(np.random.rand(D, D).astype(dtype))
    B = jax.numpy.array(np.random.rand(D, D).astype(dtype))
    ts = []
    for _ in range(100):
        t1 = time.time()
        res = jax.numpy.matmul(A, B)
        res.block_until_ready()
        ts.append(time.time() - t1)
    # report 100 * (median single-matmul time) to be robust against outliers
    print(f'loop over 100 matrix-matrix muls in dtype {np.dtype(dtype).name}', np.median(ts) * 100)
    res = do_matmul_simple(A, B)  # warm-up call to trigger compilation
    res.block_until_ready()
    t1 = time.time()
    res = do_matmul_simple(A, B)
    res.block_until_ready()
    print(f'do_matmul_simple (100 matmuls, jitted) for dtype {np.dtype(dtype).name}', time.time() - t1)

print('######## timings for matrix-matrix ###########')
print(' ---------- float --------------')
run_timings_matmul_simple(np.float32, 400)
print(' ---------- complex --------------')
run_timings_matmul_simple(np.complex64, 400)
```
Yes, I think there's a bug here. On CPU, XLA is falling back to a naive implementation of matmul for complex types instead of calling into an optimized implementation as it does for floating point types. This should be easy to fix.
This bug should be fixed in jaxlib 0.1.48 or newer (jaxlib 0.1.50 was released yesterday).
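For anyone landing here later, one quick way to confirm which jaxlib you have installed (assuming a standard install where the `jaxlib.version` module is importable):

```python
import jaxlib.version
print(jaxlib.version.__version__)  # should be 0.1.48 or newer to include the fix
```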
Thanks for fixing!