It would be great to have a multivariate Gaussian pdf/logpdf implementation, similar to the univariate version in jax.scipy.stats.norm. I am currently working with this hacky function:
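# Assumes `np` is jax.numpy, as the calls below suggest.
import jax.numpy as np
from jax import jit, vmap

@jit
def multi_gauss_logpdf(x, mean, cov):
    """ Calculate the log probability density of a
        sample from the multivariate normal. """
    D = mean.shape[0]
    # The sign is unused: cov is assumed positive definite.
    (sign, logdet) = np.linalg.slogdet(cov)
    p1 = D*np.log(2*np.pi) + logdet
    p2 = (x-mean).T @ np.linalg.inv(cov) @ (x-mean)
    return -1./2 * (p1 + p2)

# Map over a batch of samples x; mean and cov are shared.
batch_logpdf = vmap(multi_gauss_logpdf, in_axes=(0, None, None))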
My lax/primitive knowledge is still fairly limited, but I will try to put together a PR. Any recommendations on how to speed things up?
@RobertTLange are you actively working on a PR? If not, I can probably help with this one, since I've written similar code too many times in the past!
It's usually better (faster and more numerically stable) to do a Cholesky factorization and a triangular solve, and you can compute the log det in terms of the diagonal elements of the Cholesky factor too. Here's an example.
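For readers following along, here is a minimal sketch of the Cholesky-plus-triangular-solve idea described above. It is not the example linked in the comment, and the names `mvn_logpdf_chol` and `batch_mvn_logpdf_chol` are hypothetical. It assumes `cov` is symmetric positive definite and that `x` and `mean` are 1-D vectors of the same length.

```python
import jax.numpy as jnp
from jax import jit, vmap
from jax.scipy.linalg import cholesky, solve_triangular


@jit
def mvn_logpdf_chol(x, mean, cov):
    """Log density of N(mean, cov) at one point x (hypothetical helper)."""
    dim = mean.shape[0]
    # Factor cov = L @ L.T with L lower triangular.
    # Assumption: cov is symmetric positive definite.
    chol = cholesky(cov, lower=True)
    # Solve L z = (x - mean); z @ z then equals
    # (x - mean).T @ inv(cov) @ (x - mean) without forming inv(cov).
    z = solve_triangular(chol, x - mean, lower=True)
    # log det(cov) = 2 * sum(log(diag(L))).
    logdet = 2.0 * jnp.sum(jnp.log(jnp.diag(chol)))
    return -0.5 * (dim * jnp.log(2.0 * jnp.pi) + logdet + z @ z)


# Map over a batch of points x; mean and cov are shared across the batch.
batch_mvn_logpdf_chol = vmap(mvn_logpdf_chol, in_axes=(0, None, None))
```

This avoids an explicit matrix inverse and a separate determinant computation, since both the quadratic form and the log determinant come from the single factor `chol`.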
@mattjj thanks for coming back and sorry for the late response. Yes, it is on my todo-list. My focus is suffering a little. Should be able to get this done in the next few days. Okaydokey, then I have learned some new computational linear algebra tricks ;)
P.S.: I wrote a little JAX intro tutorial (https://roberttlange.github.io/posts/2020/03/blog-post-10/). Thanks for the great project and all the effort!
Awesome tutorial, and artwork!
Actually, I just noticed: a multivariate normal logpdf function seems to be checked in already from #268, but it's inefficient because of how it computes inv and det (it even computes inv twice!). Looks like I LGTM'd it though, probably thinking we could improve the efficiency and numerics in follow-up work.
Does that function work for you? If so, we can change this issue title to be about improving it.