Jax: Add multivariate normal pdf evaluation to jax.scipy

Created on 26 Feb 2020 · 3 comments · Source: google/jax

It would be great to have a multivariate Gaussian pdf/logpdf implementation, similar to the univariate version in jax.scipy.stats.norm. I am currently working with this hacky function:

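import jax.numpy as np
from jax import jit, vmap

@jit
def multi_gauss_logpdf(x, mean, cov):
    """ Calculate the log probability density of a
        sample from the multivariate normal. """
    D = mean.shape[0]
    # slogdet returns (sign, log|det|); sign is +1 for a valid covariance
    (sign, logdet) = np.linalg.slogdet(cov)
    p1 = D*np.log(2*np.pi) + logdet
    # Quadratic (Mahalanobis) term: (x - mean)^T cov^{-1} (x - mean)
    p2 = (x-mean).T @ np.linalg.inv(cov) @ (x-mean)
    return -1./2 * (p1 + p2)

# Evaluate a batch of samples x (shape [N, D]) against a shared mean and cov
batch_logpdf = vmap(multi_gauss_logpdf, in_axes=(0, None, None))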
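
For reference, the function above evaluates the standard log-density of a D-dimensional normal with mean μ and covariance Σ; `p1` holds the first two terms inside the brackets and `p2` holds the quadratic term:

$$
\log \mathcal{N}(x \mid \mu, \Sigma) = -\frac{1}{2}\left[ D \log(2\pi) + \log\lvert\Sigma\rvert + (x-\mu)^\top \Sigma^{-1} (x-\mu) \right]
$$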

My lax/primitive knowledge is still fairly limited, but I will try to put together a PR. Any recommendations on how to speed things up?

enhancement

All 3 comments

@RobertTLange are you actively working on a PR? If not, I can probably help with this one, since I've written similar code too many times in the past!

It's usually better (faster and more numerically stable) to do a Cholesky and a triangular solve, and you can compute the log det in terms of the diagonal elements of the Cholesky too. Here's an example.
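
A minimal sketch of the Cholesky approach, assuming a JAX version where `jax.scipy.linalg.solve_triangular` accepts a 1-D right-hand side. The names `mvn_logpdf_chol` and `batch_mvn_logpdf` are hypothetical, not part of any JAX API. With Σ = L Lᵀ, the quadratic term equals zᵀz where L z = x − μ, and log|Σ| = 2 Σᵢ log Lᵢᵢ, so neither an explicit inverse nor a separate determinant is needed.

```python
import jax.numpy as jnp
from jax import jit, vmap
from jax.scipy.linalg import solve_triangular

@jit
def mvn_logpdf_chol(x, mean, cov):
    """Log density of a multivariate normal, computed via a Cholesky factor."""
    D = mean.shape[0]
    # Lower-triangular L with cov = L @ L.T (requires cov to be positive definite)
    L = jnp.linalg.cholesky(cov)
    # Solve L z = (x - mean); then (x - mean)^T cov^{-1} (x - mean) = z^T z
    z = solve_triangular(L, x - mean, lower=True)
    maha = jnp.dot(z, z)
    # log|cov| = 2 * sum(log(diag(L)))
    logdet = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
    return -0.5 * (D * jnp.log(2.0 * jnp.pi) + logdet + maha)

# Batched over samples x of shape [N, D], sharing mean and cov
batch_mvn_logpdf = vmap(mvn_logpdf_chol, in_axes=(0, None, None))
```

The factorization is still O(D³), but a single factor serves both the log-determinant and the quadratic term, and the triangular solve avoids forming cov⁻¹ explicitly.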

@mattjj thanks for coming back, and sorry for the late response. Yes, it is on my to-do list. My focus is suffering a little. I should be able to get this done in the next few days. Okeydokey, then I have learned some new computational linear algebra tricks ;)

P.S.: I wrote a little JAX intro tutorial (https://roberttlange.github.io/posts/2020/03/blog-post-10/). Thanks for the great project and all the effort!

Awesome tutorial, and artwork!

Actually, I just noticed: a multivariate normal logpdf function seems to be checked in already from #268, but it's inefficient because of how it computes inv and det (it even computes inv twice!). Looks like I LGTM'd it though, probably thinking we could improve the efficiency and numerics in follow-up work.

Does that function work for you? If so, we can change this issue title to be about improving it.
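
To check whether the already checked-in function can replace the hand-written one, a sketch like the following compares `jax.scipy.stats.multivariate_normal.logpdf` against the same formula computed with `slogdet` and a linear solve. It assumes a JAX version that ships `jax.scipy.stats.multivariate_normal`; `reference_logpdf` is a hypothetical helper, and both functions are wrapped in `vmap` so the batching does not depend on how a given version broadcasts over `x`.

```python
import jax.numpy as jnp
from jax import vmap
from jax.scipy.stats import multivariate_normal

def reference_logpdf(x, mean, cov):
    # Same formula as the function in the original post, but using a
    # linear solve instead of an explicit inverse.
    D = mean.shape[0]
    _, logdet = jnp.linalg.slogdet(cov)
    d = x - mean
    maha = d @ jnp.linalg.solve(cov, d)
    return -0.5 * (D * jnp.log(2 * jnp.pi) + logdet + maha)

# Map over a batch of samples; mean and cov are shared
batch_scipy = vmap(multivariate_normal.logpdf, in_axes=(0, None, None))
batch_ref = vmap(reference_logpdf, in_axes=(0, None, None))

mean = jnp.array([0.5, -1.0])
cov = jnp.array([[2.0, 0.3], [0.3, 1.0]])
xs = jnp.array([[1.0, 0.0], [0.5, -1.0], [-2.0, 3.0]])

max_abs_diff = jnp.max(jnp.abs(batch_scipy(xs, mean, cov) - batch_ref(xs, mean, cov)))
print(max_abs_diff)  # should be at float32 round-off level
```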
