Jax: Multivariate Normal

Created on 22 Sep 2019 · 5 Comments · Source: google/jax

Hi all,

I noticed that sampling from predefined distributions works a bit differently in JAX than it does in NumPy.

Is it possible to use the jax.random module to sample from a multivariate normal distribution (with dense covariance matrix)? If not, is there any other way?

So far I have only managed to sample from one with unit variance, using jax.random.normal.

Cheers,

Labels: enhancement, good first issue


All 5 comments

Thanks for the question!

The best way to sample from a multivariate Gaussian is to sample a vector of iid unit-variance zero-mean Gaussians (which you can do with jax.random.normal), and then perform an affine transformation to produce the mean and covariance structure you want.

For example, let's take the zero-mean case (since it's easy to add a mean vector in later). If x ~ N(0, I) is a vector of unit-variance zero-mean Gaussians, then y = A x has covariance E[yy'] = E[(A x)(A x)'] = E[A x x' A'] = A E[xx'] A' = AA'.
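A quick numerical sanity check of the identity E[yy'] = AA'. This is only a sketch: the matrix A, the seed, and the sample count are arbitrary choices for illustration, and the empirical covariance only approximates AA'.

```python
import jax
import jax.numpy as jnp

# An arbitrary (illustrative) matrix A; the identity holds for any A.
A = jnp.array([[2.0, 0.0],
               [1.0, 0.5]])

key = jax.random.PRNGKey(0)
n_samples = 200_000

# Each row of x is an independent draw of x ~ N(0, I) in R^2.
x = jax.random.normal(key, (n_samples, 2))

# y = A x for every sample; with samples stored as rows this is x @ A.T.
y = x @ A.T

# y is zero-mean, so its covariance is the average of the outer products y y'.
empirical_cov = (y.T @ y) / n_samples
print(empirical_cov)  # approximately A @ A.T = [[4, 2], [2, 1.25]]
```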

So given a target covariance Q and mean mu, the procedure would look something like this:

  1. sample a standard Gaussian like v ~ N(0, I)
  2. return a sample w = A v + mu, where A is a square root of Q such that AA' = Q.
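The two steps above can be sketched in JAX as follows. This is a minimal sketch, not a library API: the helper name sample_mvn and its n argument (number of samples) are hypothetical, and it assumes Q is symmetric positive definite so that the Cholesky factor exists. Because samples are stored as rows, "A v + mu" for each row becomes v @ L.T + mu.

```python
import jax
import jax.numpy as jnp


def sample_mvn(key, Q, mu, n):
    """Draw n samples w ~ N(mu, Q) using the Cholesky square root of Q.

    Hypothetical helper for illustration; assumes Q is symmetric positive definite.
    """
    # Square root A = L: lower-triangular Cholesky factor with L @ L.T == Q.
    L = jnp.linalg.cholesky(Q)
    # Step 1: each row of v is an independent draw from N(0, I).
    v = jax.random.normal(key, (n, mu.shape[0]))
    # Step 2: w = L v + mu for every row, written for row-stacked samples.
    return v @ L.T + mu


Q = jnp.array([[1.0, 0.8],
               [0.8, 2.0]])
mu = jnp.array([3.0, -1.0])
samples = sample_mvn(jax.random.PRNGKey(0), Q, mu, 200_000)
print(samples.mean(axis=0))            # close to mu
print(jnp.cov(samples, rowvar=False))  # close to Q
```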

See also this section on Wikipedia.

There are many matrix square roots (all related by an orthogonal transformation on the right), and any one will work. The standard one to use here is the Cholesky square root, since (IIRC) it's the cheapest to compute in terms of FLOPs. It's usually written as L = chol(Q), where LL' = Q. (Another option is the symmetric square root, which amounts to diagonalizing Q, taking the square root of the eigenvalues, and multiplying the result back out. Fun fact: if you take the QR decomposition of the symmetric square root, then R' is the Cholesky factor, up to the signs of its columns!)
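To illustrate the square-root discussion, the sketch below computes both the Cholesky and the symmetric square root of an example covariance (the matrix values are arbitrary), checks that the two differ by an orthogonal matrix on the right, and checks the QR fact. QR is only unique up to the signs of R's rows, so the sketch flips those signs to give R a positive diagonal before comparing R' with L. The orthogonal QR factor is named U here to avoid clashing with the covariance Q.

```python
import jax.numpy as jnp

# Example symmetric positive definite covariance (values chosen for illustration).
Q = jnp.array([[4.0, 2.0, 0.6],
               [2.0, 3.0, 0.4],
               [0.6, 0.4, 2.0]])

# Cholesky square root: lower-triangular L with L L' = Q.
L = jnp.linalg.cholesky(Q)

# Symmetric square root: Q = V diag(lam) V', so S = V diag(sqrt(lam)) V'.
lam, V = jnp.linalg.eigh(Q)
S = V @ jnp.diag(jnp.sqrt(lam)) @ V.T

# Both are valid square roots of Q.
assert jnp.allclose(L @ L.T, Q, atol=1e-4)
assert jnp.allclose(S @ S.T, Q, atol=1e-4)

# They are related by an orthogonal matrix on the right: S = L U2.
U2 = jnp.linalg.solve(L, S)
assert jnp.allclose(U2 @ U2.T, jnp.eye(3), atol=1e-4)

# QR of the symmetric root: S = U R. Since S is symmetric, Q = S'S = R'R.
U, R = jnp.linalg.qr(S)
signs = jnp.sign(jnp.diag(R))
R_fixed = signs[:, None] * R  # flip rows so the diagonal of R is positive

# With a positive diagonal, R' is the (unique) Cholesky factor.
assert jnp.allclose(R_fixed.T, L, atol=1e-4)
```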

Since we have vmap, we don't need to worry about handling batch dimensions by hand. A good API for a multivariate normal might be:

# Q is an (n, n) matrix
# mu is an (n,) vector
w = random.multivariate_normal(key, Q, mu)
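One way to prototype the proposed signature is a single-sample function plus jax.vmap for batching. In the sketch below, multivariate_normal_sketch is a hypothetical stand-in for the proposed random.multivariate_normal(key, Q, mu), not an existing JAX function, and it again assumes Q is symmetric positive definite.

```python
import jax
import jax.numpy as jnp


def multivariate_normal_sketch(key, Q, mu):
    """Hypothetical single-sample version of the proposed API.

    Q: (n, n) covariance, mu: (n,) mean. Returns one (n,) sample.
    """
    L = jnp.linalg.cholesky(Q)           # L @ L.T == Q
    v = jax.random.normal(key, mu.shape)  # v ~ N(0, I)
    return L @ v + mu                     # w ~ N(mu, Q)


Q = jnp.array([[1.0, 0.3],
               [0.3, 0.5]])
mu = jnp.array([0.0, 10.0])

# Batch by mapping over a stack of keys while sharing Q and mu.
keys = jax.random.split(jax.random.PRNGKey(42), 50_000)
batched = jax.vmap(multivariate_normal_sketch, in_axes=(0, None, None))
samples = batched(keys, Q, mu)  # shape (50_000, 2)
```

With in_axes=(0, None, None), each key yields an independent draw sharing the same Q and mu; mapping over mu as well (in_axes=(0, None, 0)) would give one draw per mean. Recent JAX releases ship jax.random.multivariate_normal, whose arguments are ordered (key, mean, cov) rather than (key, Q, mu); check the documentation for your version.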

Thanks for the answer!

Sorry for the confusion. I was mostly interested in whether a multivariate normal function is already implemented in JAX, since I couldn't find one.

Are there any plans to add more distributions (including a multivariate normal) to jax.random?

P.S. I ended up doing the Cholesky decomposition.

In case you find it useful, we have a multivariate normal distribution in NumPyro that basically does what Matt proposed above (the remaining code is just there to comply with our distributions API). cc @fehiepsi

In JAX, there is a pending PR #269 which does the job.

Thanks a lot! I think this answers my question.
