jax.numpy.linalg ops missing GPU implementation

Created on 8 Jun 2019 · 7 comments · Source: google/jax

I often hit messages like
"Singular value decomposition is only implemented on the CPU backend"
(https://github.com/google/jax/blob/27746b8c73f9ca9928da5da40b7382ae648a5f8d/jax/lax_linalg.py)
for many jax.numpy.linalg ops. Examples I've hit so far include calling:
jax.numpy.linalg.svd;
jax.numpy.linalg.eigh.
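(For context, jax.numpy.linalg mirrors the NumPy API, so on a backend where these primitives are implemented the calls behave like their plain NumPy counterparts. A quick NumPy sketch of what the two ops compute, using hypothetical small inputs:)

```python
import numpy as np

# A small symmetric matrix, so both svd and eigh apply.
rng = np.random.default_rng(0)
a = rng.standard_normal((4, 4))
a = a + a.T  # symmetrize

# Singular value decomposition: a == u @ diag(s) @ vh.
u, s, vh = np.linalg.svd(a)
assert np.allclose(u @ np.diag(s) @ vh, a)

# Symmetric eigendecomposition: a == v @ diag(w) @ v.T.
w, v = np.linalg.eigh(a)
assert np.allclose(v @ np.diag(w) @ v.T, a)
```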

It would be nice to be able to run on GPU.

Thanks!

enhancement

All 7 comments

Thanks for raising this. Given your specific request here, what do you think about changing the issue title to be a request to get linalg support on GPU? That's probably easier for us to add.

Done, thank you!

I want to know if jax.numpy.linalg.inv ops can run on the GPU?

As it happens, jax.numpy.linalg.inv should run on GPU right now! (It might not be terribly fast, since it's using a QR decomposition instead of an LU decomposition, and one that isn't necessarily that well tuned.)
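(For intuition on the QR route: with A = QR, where Q is orthogonal and R is upper triangular, the inverse is A⁻¹ = R⁻¹Qᵀ. A minimal NumPy sketch of that idea, not JAX's actual implementation:)

```python
import numpy as np

def inv_via_qr(a):
    """Invert a square matrix through its QR decomposition.

    a = q @ r with q orthogonal and r upper triangular, so
    inv(a) = inv(r) @ q.T. A triangular solve would be cheaper;
    we use np.linalg.solve on r here for brevity.
    """
    q, r = np.linalg.qr(a)
    return np.linalg.solve(r, q.T)

a = np.array([[4.0, 1.0], [2.0, 3.0]])
assert np.allclose(inv_via_qr(a) @ a, np.eye(2))
```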

Another thing missing is `np.linalg.solve`:

solve = np.linalg.solve(np.eye(3), np.ones(3))
Traceback (most recent call last):
  File "/opt/intellij-ue-2019.1/plugins/python/helpers/pydev/_pydevd_bundle/pydevd_exec2.py", line 3, in Exec
    exec(exp, global_vars, local_vars)
  File "<input>", line 1, in <module>
  File "/usr/local/google/_blaze_romann/513cef43ffae8d7478c0c7058e6a84e4/execroot/google3/blaze-out/k8-cuda9-py3-opt/bin/experimental/users/romann/ntk_tuner/train_and_eval.runfiles/google3/third_party/py/jax/numpy/linalg.py", line 237, in solve
    lu, pivots = lax_linalg.lu(a)
  File "/usr/local/google/_blaze_romann/513cef43ffae8d7478c0c7058e6a84e4/execroot/google3/blaze-out/k8-cuda9-py3-opt/bin/experimental/users/romann/ntk_tuner/train_and_eval.runfiles/google3/third_party/py/jax/lax_linalg.py", line 53, in lu
    lu, pivots = lu_p.bind(x)
  File "/usr/local/google/_blaze_romann/513cef43ffae8d7478c0c7058e6a84e4/execroot/google3/blaze-out/k8-cuda9-py3-opt/bin/experimental/users/romann/ntk_tuner/train_and_eval.runfiles/google3/third_party/py/jax/core.py", line 136, in bind
    return self.impl(*args, **kwargs)
  File "/usr/local/google/_blaze_romann/513cef43ffae8d7478c0c7058e6a84e4/execroot/google3/blaze-out/k8-cuda9-py3-opt/bin/experimental/users/romann/ntk_tuner/train_and_eval.runfiles/google3/third_party/py/jax/lax_linalg.py", line 358, in lu_impl
    lu, pivot = xla.apply_primitive(lu_p, operand)
  File "/usr/local/google/_blaze_romann/513cef43ffae8d7478c0c7058e6a84e4/execroot/google3/blaze-out/k8-cuda9-py3-opt/bin/experimental/users/romann/ntk_tuner/train_and_eval.runfiles/google3/third_party/py/jax/interpreters/xla.py", line 52, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *abstract_args, **params)
  File "/usr/local/google/_blaze_romann/513cef43ffae8d7478c0c7058e6a84e4/execroot/google3/blaze-out/k8-cuda9-py3-opt/bin/experimental/users/romann/ntk_tuner/train_and_eval.runfiles/google3/third_party/py/jax/util.py", line 174, in memoized_fun
    ans = cache[key] = fun(*args, **kwargs)
  File "/usr/local/google/_blaze_romann/513cef43ffae8d7478c0c7058e6a84e4/execroot/google3/blaze-out/k8-cuda9-py3-opt/bin/experimental/users/romann/ntk_tuner/train_and_eval.runfiles/google3/third_party/py/jax/interpreters/xla.py", line 58, in xla_primitive_callable
    built_c = primitive_computation(prim, *shapes, **params)
  File "/usr/local/google/_blaze_romann/513cef43ffae8d7478c0c7058e6a84e4/execroot/google3/blaze-out/k8-cuda9-py3-opt/bin/experimental/users/romann/ntk_tuner/train_and_eval.runfiles/google3/third_party/py/jax/util.py", line 174, in memoized_fun
    ans = cache[key] = fun(*args, **kwargs)
  File "/usr/local/google/_blaze_romann/513cef43ffae8d7478c0c7058e6a84e4/execroot/google3/blaze-out/k8-cuda9-py3-opt/bin/experimental/users/romann/ntk_tuner/train_and_eval.runfiles/google3/third_party/py/jax/interpreters/xla.py", line 69, in primitive_computation
    xla_result = translation_rule(prim)(c, *xla_args, **params)
  File "/usr/local/google/_blaze_romann/513cef43ffae8d7478c0c7058e6a84e4/execroot/google3/blaze-out/k8-cuda9-py3-opt/bin/experimental/users/romann/ntk_tuner/train_and_eval.runfiles/google3/third_party/py/jax/lax_linalg.py", line 363, in lu_translation_rule
    "LU decomposition is only implemented on the CPU backend")
NotImplementedError: LU decomposition is only implemented on the CPU backend

There is now an LU decomposition implementation that works on GPU. However, it may not be the most performant, since it is implemented in JAX itself rather than as a native kernel. We would still do well to add a specialized GPU implementation that calls cuSolver or MAGMA.
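(For reference, the LU primitive factors the matrix as P·A = L·U with row pivoting, and a solve then reduces to two triangular solves. A minimal Doolittle-style NumPy sketch of the factorization, not the tuned implementation JAX ships:)

```python
import numpy as np

def lu_partial_pivot(a):
    """Factor a square matrix so that a[perm] == lower @ upper.

    Doolittle elimination with partial pivoting; multipliers are
    stored in place below the diagonal. O(n^3), illustrative only.
    """
    a = np.array(a, dtype=float)
    n = a.shape[0]
    perm = np.arange(n)
    for k in range(n - 1):
        # Pivot on the largest remaining entry in column k.
        p = k + np.argmax(np.abs(a[k:, k]))
        if p != k:
            a[[k, p]] = a[[p, k]]
            perm[[k, p]] = perm[[p, k]]
        # Eliminate below the pivot, keeping the multipliers.
        a[k + 1:, k] /= a[k, k]
        a[k + 1:, k + 1:] -= np.outer(a[k + 1:, k], a[k, k + 1:])
    lower = np.tril(a, -1) + np.eye(n)
    upper = np.triu(a)
    return lower, upper, perm

m = np.array([[0.0, 2.0], [1.0, 3.0]])  # m[0, 0] == 0 forces a pivot
l, u, perm = lu_partial_pivot(m)
assert np.allclose(l @ u, m[perm])
```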

(FYI, slight overlap with https://github.com/google/jax/issues/723)

