JAX: np.linalg.inv support

Created on 10 Dec 2018 · 9 comments · Source: google/jax

I've been trying to implement Gaussian process regression, which requires computing a matrix inverse. With regular numpy I would use np.linalg.inv, but I can't find this function in jax.

Everything else is working as expected, and I can fall back to regular numpy's np.linalg.inv for basic calculations.
Unfortunately, using np.linalg.inv keeps me from applying grad to compute gradients, which would be the most exciting part of the whole implementation!

I would love to contribute a PR if someone can tell me where to start.

enhancement

All 9 comments

Thanks for bringing this up, and for the offer to help! Good np.linalg linear algebra support was one of the best parts of Autograd, with a lot of users choosing Autograd just for that. It's also near and dear to my heart because working on the linear algebra support is how I first got into developing Autograd, and it's critical to my kind of machine learning research. JAX needs to get there too!

I believe the main challenge is getting XLA to call into backend-specific linear algebra routines (e.g. in LAPACK and MAGMA) using the CustomCall HLO on CPU and GPU (and, more generally, call into the XLA client library for TPU). @hawkinsp has already started to look into what we need. Once we can generate calls into these routines, we can use similar rules as in Autograd for differentiation.

An alternative route is just to implement the algorithms we need in terms of lax primitives, since we can already compile and differentiate all of those (on any backend). That's the approach taken in the jax.experimental.lapax module, which just has cholesky and solve_triangular. As you can see from e.g. the cholesky routine, these algorithms aren't so bad to implement ourselves, but they'll likely be slower on CPU and GPU than the extremely well optimized LAPACK and MAGMA kernels. (Even for TPU, we'd rather reuse HLO implementations in the nascent XLA client library than duplicate that effort in JAX, as convenient as it is to write this code in Python.)

We'd love to get contributions on this. I think the best course of action is to wait for us to sort out our CustomCall / client library story, which we might be able to do this week (fingers crossed), and to make it work with one linear algebra routine (e.g. cholesky). Once we get one example working, it would be really helpful for contributors to dive in and help us set up the rest, along with their derivatives.

How does that sound to you?

To unstick your work for now, given that you're working on GPs you might be able to work with the cholesky and solve_triangular routines in jax.experimental.lapax. You might also be able to implement something like a CG iteration, though that path has its own bumps that need ironing out (@jaspersnoek mentioned he'd had some success in this direction).
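
For instance, a bare-bones CG iteration written purely with jax.numpy operations might look roughly like this (a sketch assuming a symmetric positive definite system; the helper name and fixed iteration count are just for illustration):

```
import jax.numpy as np

def cg_solve(A, b, num_iters=50):
    # Plain conjugate gradients for a symmetric positive definite A.
    # Illustrative helper, not part of JAX.
    x = np.zeros_like(b)
    r = b - np.dot(A, x)
    p = r
    rs_old = np.dot(r, r)
    for _ in range(num_iters):
        Ap = np.dot(A, p)
        alpha = rs_old / np.dot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = np.dot(r, r)
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x
```

Because it only uses jax.numpy operations, grad and jit can trace through it (the Python loop simply gets unrolled).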

One other thing: it's not documented yet, but JAX has a configuration option for enabling 64-bit dtypes, which are off by default (instead capping everything to 32 bits). You can switch it on like this:

```
import jax.numpy as np

from jax.config import config
config.update("jax_enable_x64", True)

print(np.dot(np.zeros(2), np.zeros(2)).dtype)   # should print 'float64'
```

Here's a quick example of sampling from a GP with a squared exponential covariance function using jax.

```
import numpy as onp
import matplotlib.pyplot as plt
import jax.numpy as np
import jax.random as random
from jax.experimental import lapax

key = random.PRNGKey(0)
numpts = 50

def dist(x1, x2):
    # Squared Euclidean distances between the rows of x1 and the rows of x2.
    return -2. * np.dot(x1, x2.T) + np.sum(x2 ** 2, axis=1) + np.sum(x1 ** 2, axis=1)[:, None]

def cov(x1, x2):
    return np.exp(-dist(x1, x2))

x = onp.linspace(0., 1., numpts)[:, None]

# Add a small jitter to the diagonal so the Cholesky factorization is stable.
K = cov(x, x) + onp.eye(x.shape[0]) * 1e-6
L = lapax.cholesky(K)

normal_samp = random.normal(key, shape=(x.shape[0], 10))
y_hat = np.dot(L, normal_samp)
plt.plot(x, y_hat)
```

Nice, thanks so much @JasperSnoek!

I think eye recently made it into jax.numpy; I'll add linspace now too so that the onp import won't be necessary. EDIT: never mind, it looks like it's already in there, but we should clean up how we handle those constant-creation functions.

I must admit I had already seen the Cholesky decomposition in lapax, but hadn't yet made the connection to using it here! I am going to dive into this right now.

For completeness' sake, here is what I have done so far:

```
import jax.numpy as np
from jax import random, grad

from numpy import log10, diag
from numpy.linalg import inv, det

import matplotlib.pyplot as plt

LOG2PI = log10(2 * np.pi)

def ard_kernel(x1, x2, length_scale):
    return np.exp(-length_scale * ((x1.T - x2) ** 2))

def gp_predict(x_train, y_train, x_test, variance, length_scale, kernel=ard_kernel):
    k_x_x = kernel(x_train, x_train, length_scale)
    k_x_xs = kernel(x_train, x_test, length_scale)
    k_xs_x = k_x_xs.T
    k_xs_xs = kernel(x_test, x_test, length_scale)

    v_inv = inv(k_x_x + variance * np.eye(x_train.size))
    q = np.dot(k_x_xs, v_inv)

    mu = np.dot(q, y_train).T[0]
    cov = k_xs_xs - np.dot(q, k_xs_x)
    sigma = diag(cov)

    return mu, sigma

def gp_log_marginal_likelihood(x_train, y_train, variance, length_scale, kernel=ard_kernel):
    k_x_x = kernel(x_train, x_train, length_scale)
    k_var = k_x_x + variance * np.eye(x_train.size)
    v_inv = inv(k_var)

    data_fit = np.dot(np.dot(-.5 * y_train.T, v_inv), y_train)
    complexity = .5 * np.log(det(k_var))
    size_correction = .5 * x_train.size * LOG2PI
    return (-data_fit - complexity - size_correction)[0][0]

length_scale = 100
variance = .1
size = 50

random_key = random.PRNGKey(0)

x_train = 2 * random.uniform(random_key, shape=(size,)).reshape((-1, 1))
y_train = np.sin(x_train * 10) + random.normal(random_key, shape=(size,)).reshape((-1, 1))

x_test = np.arange(-.5, 2.5, .05).reshape((-1, 1))

mu, sigma = gp_predict(x_train, y_train, x_test, variance, length_scale)

plt.plot(x_test, mu)
plt.fill_between(x_test.T[0], mu + 2 * sigma, mu - 2 * sigma, alpha=.5)

plt.scatter(x_train.T, y_train.T, c='black')
```

This follows Rasmussen & Williams quite literally, which is why I resorted to np.linalg.inv in the first place. In addition, there are a few functions that I haven't found in JAX yet: log10, diag, and np.linalg.det.

From what I understand, and have seen in examples, computing the predictive mean via the Cholesky decomposition uses np.linalg.solve. What would be the JAX approach for this?

The other function implemented in jax.experimental.lapax is a triangular solve, which you can use in place of np.linalg.solve when the coefficient matrix is triangular (as a Cholesky factor is). You can also compute the log determinant easily given the Cholesky factor, using the fact that the determinant of a triangular matrix is the product of its diagonal entries, together with the fact that the determinant of a product of matrices is the product of their determinants. Those are the most efficient ways to implement these computations in NumPy too (Cholesky requires fewer FLOPs than LU, and reusing the factor saves a lot of work in the logdet calculation).
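
To make that concrete, here's a rough sketch of how the predictive mean and log determinant could be computed from a Cholesky factor using the two lapax routines; the helper name, the use of onp for index arrays, and the positional flags (which follow the solve_triangular signature quoted below) are illustrative rather than a fixed API:

```
import numpy as onp
import jax.numpy as np
from jax.experimental import lapax

def gp_mean_and_logdet(k_x_x, k_xs_x, y_train, variance):
    n = k_x_x.shape[0]
    k_var = k_x_x + variance * np.eye(n)
    chol = lapax.cholesky(k_var)   # k_var = chol @ chol.T, chol lower triangular

    # Two triangular solves replace inv(k_var) @ y_train:
    #   solve chol @ z = y_train, then chol.T @ alpha = z.
    z = lapax.solve_triangular(chol, y_train, True, True, False)
    alpha = lapax.solve_triangular(chol, z, True, True, True)
    mu = np.dot(k_xs_x, alpha)

    # log|k_var| = 2 * sum(log(diag(chol))); grab the diagonal via fancy indexing.
    idx = onp.arange(n)
    log_det = 2. * np.sum(np.log(chol[idx, idx]))
    return mu, log_det
```

The predictive covariance can be handled the same way, by solving against k_x_xs instead of y_train and reusing the same factor.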

We should add direct support for np.log10 and np.diag, but for now you can get them by computing np.log(x) / np.log(10) and by using fancy indexing, respectively.
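
For example (a tiny sketch; the array here is just a placeholder):

```
import numpy as onp
import jax.numpy as np

x = np.array([[1., 2.], [3., 4.]])

# np.log10 workaround: change of base.
log10_x = np.log(x) / np.log(10.)

# np.diag workaround (diagonal extraction): fancy indexing with index arrays.
idx = onp.arange(x.shape[0])
diag_x = x[idx, idx]   # [1., 4.]
```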

Awesome! Almost there, I just need to figure out how to get the correct arguments for lapax.solve_triangular:

```
def solve_triangular(a, b, left_side, lower, trans_a, block_size=1):
  """An unrolled triangular solve."""
  return _solve_triangular_right(LapaxMatrix(a, block_size),
                                 LapaxMatrix(b, block_size),
                                 left_side, lower, trans_a).ndarray
```

What are left_side, lower and trans_a?

left_side, lower, and trans_a have roughly the same meaning they do in LAPACK or scipy:
http://www.netlib.org/lapack/explore-html/de/da7/dtrsm_8f_source.html

There's also some documentation on a similar C++ API here:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/client/lib/triangular_solve.h
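
In other words, the flags pick which system is solved. A small sketch using the signature quoted above (the matrices are made up for illustration):

```
import jax.numpy as np
from jax.experimental import lapax

a = np.array([[2., 0.], [1., 3.]])   # lower-triangular coefficient matrix
b = np.array([[4.], [7.]])           # right-hand side

# left_side=True, lower=True, trans_a=False  ->  solve a @ x = b
x = lapax.solve_triangular(a, b, True, True, False)

# trans_a=True  ->  solve a.T @ x = b instead
xt = lapax.solve_triangular(a, b, True, True, True)

# left_side=False would put the coefficient matrix on the right: solve x @ a = b.
```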

I'm also actively working on adding cholesky and solve_triangular implementations to JAX, with the standard numpy/scipy APIs.

PR #110 added a first cut at np.linalg.inv. The performance probably isn't optimal on CPU or GPU right now, but it might suffice for your needs. Try it out and see if it works for you!

(The next step to improve performance is to link in a LAPACK version of the kernels, e.g., MKL or OpenBLAS, but let's consider that in another issue.)
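
As a quick smoke test, grad should now differentiate straight through it; something along these lines (the function and test matrix are just an illustration):

```
import jax.numpy as np
from jax import grad

def quadratic_form(a):
    # 0.5 * y^T inv(a) y for a fixed y, a stand-in for a GP-style data-fit term.
    y = np.ones(a.shape[0])
    return 0.5 * np.dot(y, np.dot(np.linalg.inv(a), y))

a = np.array([[2., 0.5], [0.5, 1.]])
print(quadratic_form(a))         # scalar value
print(grad(quadratic_form)(a))   # gradient with respect to the matrix entries
```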
