I am doing some experiments with probabilistic programming in JAX. Lexically, it was quite easy to port the project from autograd to JAX, but I am seeing a performance hit on the CPU, especially for small models. Specifically, I see a ~100x slowdown in computing a log probability compared to autograd, and a 1.3x slowdown in computing a gradient. Those experiments are below.
Putting all this together, I see a ~6x slowdown running Hamiltonian Monte Carlo with JAX compared to autograd on the CPU (I wrapped the gradient with jit but have not spent much more time tuning).
I see #427 suggests that benefits appear for larger functions that are not dominated by dispatch overhead. Are there any other suggestions or best practices for computations like these? I may restructure the program to allow jit to be used more often and see how much that helps.
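As one concrete restructuring direction (a hypothetical sketch, not code from this issue — leapfrog_step and the step size are made up for illustration): fusing several small operations into a single jitted function means one dispatch per leapfrog step instead of one per primitive.

```python
import jax
import jax.numpy as jnp

def neg_logp(x):
    """Negative log density of N(x | 1., 0.1), matching the toy example below."""
    return 0.5 * (jnp.log(2 * jnp.pi * 0.1 * 0.1) + ((x - 1.0) / 0.1) ** 2)

# Jit one whole leapfrog step so the gradient, scaling, and position/momentum
# updates compile into a single XLA computation with a single dispatch.
@jax.jit
def leapfrog_step(x, p, step_size=0.01):
    p = p - 0.5 * step_size * jax.grad(neg_logp)(x)  # half-step momentum update
    x = x + step_size * p                            # full-step position update
    p = p - 0.5 * step_size * jax.grad(neg_logp)(x)  # half-step momentum update
    return x, p

x, p = leapfrog_step(0.1, 0.0)
```

After the first (compiling) call, repeated calls pay only one dispatch per step rather than one per arithmetic op.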
from jax import jit
import jax.numpy as jnp
import numpy as onp
import autograd.numpy as anp
def logp(x):
    """Negative log density of N(x | 1., 0.1)"""
    return 0.5 * (onp.log(2 * onp.pi * 0.1 * 0.1) + ((x - 1.) / 0.1) ** 2)

@jit
def jlogp_jit(x):
    """Negative log density of N(x | 1., 0.1)"""
    return 0.5 * (jnp.log(2 * jnp.pi * 0.1 * 0.1) + ((x - 1.) / 0.1) ** 2)

def jlogp(x):
    """Negative log density of N(x | 1., 0.1)"""
    return 0.5 * (jnp.log(2 * jnp.pi * 0.1 * 0.1) + ((x - 1.) / 0.1) ** 2)

def alogp(x):
    """Negative log density of N(x | 1., 0.1)"""
    return 0.5 * (anp.log(2 * anp.pi * 0.1 * 0.1) + ((x - 1.) / 0.1) ** 2)
%timeit logp(0.1)
# 1.14 µs ± 92.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit jlogp_jit(0.1)
# 214 µs ± 21.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit jlogp(0.1)
# 674 µs ± 33.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit alogp(0.1)
# 2.16 µs ± 157 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
This holds true to a lesser extent for gradients as well:
from jax import grad as jgrad
from autograd import grad as agrad
adlogp = agrad(alogp)
jdlogp = jgrad(jlogp)
jdlogp_jit = jit(jgrad(jlogp))
jdlogp_jit_jit = jit(jgrad(jlogp_jit))
%timeit adlogp(0.1)
# 156 µs ± 8.96 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit jdlogp(0.1)
# 3.44 ms ± 500 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit jdlogp_jit(0.1)
# 232 µs ± 25.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit jdlogp_jit_jit(0.1)
# 204 µs ± 13 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
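One benchmarking caveat worth noting (an assumption about the setup above, since the timing commands don't show it): JAX dispatches work asynchronously, so %timeit on a jitted function can measure dispatch time rather than compute time unless the result is blocked on. A minimal sketch of the pattern:

```python
import jax
import jax.numpy as jnp

@jax.jit
def jlogp_jit(x):
    """Negative log density of N(x | 1., 0.1), same as above."""
    return 0.5 * (jnp.log(2 * jnp.pi * 0.1 * 0.1) + ((x - 1.0) / 0.1) ** 2)

# Warm up once so the first-call compilation is excluded from timings ...
result = jlogp_jit(0.1)

# ... then force the asynchronous computation to finish before the clock stops:
# %timeit jlogp_jit(0.1).block_until_ready()
value = float(result.block_until_ready())
```

For a function this small the difference is likely minor, but on larger computations blocking is what separates dispatch overhead from actual compute.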
Just as a followup, for sampling in high dimensions (this is a 250-dimensional gaussian), jax is about 10x faster than similar autograd code.
import jax.numpy as np
from minimc import neg_log_mvnormal, hamiltonian_monte_carlo
dim = 250
mu = np.zeros(dim)
cov = 0.8 * np.ones((dim, dim)) + 0.2 * np.eye(dim)
neg_log_p = neg_log_mvnormal(mu, cov)
samples = hamiltonian_monte_carlo(1_000, neg_log_p, np.zeros(dim), path_len=10., step_size=0.2)
I think @mattjj has written a great explanation of this issue in this thread.
I see that you are also interested in implementing Hamiltonian Monte Carlo using JAX. We have built HMC/NUTS with adaptation in JAX, and according to our initial benchmark on a small model, JAX's performance is quite impressive. ;) In addition, the functional style of JAX motivated us to convert the recursive structure of NUTS into an iterative scheme (AFAIK we are the first to do this conversion; I'm not sure why no one did it previously). As a result, NUTS has only a small overhead over HMC. I don't mean to sell anything, but I want to ask if you are interested in building a great HMC framework for the JAX community? If so, can we set up a collaboration? :)
Oops... @ColCarroll I just went over your minimc repo and see that we are doing overlapping work! >"< Your plan is quite nice. If you go further with your roadmap and implement the algorithms in the NUTS papers, I think you will observe that the overhead of NUTS is pretty large with JAX, especially when your tree is big. If that is the case, please let me know; I am happy to write a note on how to transform recursive NUTS to iterative NUTS for you. I believe that although our work overlaps, different APIs can benefit users by giving them more choices. Cheers,
Thank you for the kind words! I spotted numpyro yesterday and it looks very nice. Possibly the biggest design difference is that I am treating this only as a reference and learning implementation, whereas numpyro looks like you could actually use it for real work! I was impressed by autograd for letting me write concise and readable code, and was hoping not to have to change it too much to use jax.
The iterative NUTS implementation is very interesting - I am taking a closer look at that now. I have been looking at transforming the tree doubling into an iterative scheme, and I was impressed to see you all had done it. I would love to read any thoughts you had on it (in fact, I know a few people who would...)
I had a (very short) conversation about it with one of the Stan developers, and he was not convinced that the recursion ever goes deep enough for there to be a performance benefit to it (I think both Stan and PyMC3 use a default of 10 doublings before issuing a warning, which really is not that deep). That said, it seems like it gives you a bit more control in searching for the right scale for a trajectory, and might make the code more readable/maintainable, which is not a small thing.
I'm happy that you are also interested in iterative NUTS! I will share a note on translating recursive NUTS to iterative NUTS with you in a couple of days. :)
About the benefit: yes, it gives us more control over small models, especially when using JAX, where the per-step overhead is pretty large. For example, if each leapfrog step incurs 100 µs of overhead, then it takes ~100 ms to build a full tree (at depth=10, that is 1024 leapfrog steps), so ~100 s to get 1000 samples, which is very costly. :( With iterative NUTS, we can jit the whole trajectory, which in turn means the overhead is just in the range of 100 µs per trajectory. We did a benchmark here showing that it took 1 s to build a full tree with the recursive algorithm but only 500 µs with the iterative one. But I also think the algorithm is less beneficial in other frameworks where the overhead of each leapfrog step is small (say, about 1 µs), or when the computational cost of each leapfrog step is in the range of milliseconds (as in the covertype example of the "Simple, Distributed, and Accelerated Probabilistic Programming" paper).
(closing because this discussion was very useful, thank you!)
@ColCarroll I just wrote a note here about iterative NUTS. I hope you will enjoy it. ^^ (And please forgive my language if it makes the details hard to understand; I am working to improve my writing style.)
cc @neerajprad
@ColCarroll - As @fehiepsi said, the main reason for us going the iterative route wasn't code readability or any framework-independent performance benefit (which you would be right to be skeptical about), but simply being able to use JAX primitives and JIT the entire tree-building step using lax.while_loop, which results in an order-of-magnitude difference in performance.
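To make that concrete, here is a minimal sketch (not numpyro's actual implementation — the target, step size, and fixed step count are made up for illustration) of jitting an entire trajectory with lax.while_loop, so the Python interpreter is out of the loop entirely:

```python
import jax
import jax.numpy as jnp
from jax import lax

def neg_logp(x):
    """Toy target: negative log density of N(1., 0.1)."""
    return 0.5 * (jnp.log(2 * jnp.pi * 0.1 * 0.1) + ((x - 1.0) / 0.1) ** 2)

# The whole loop compiles to one XLA computation: one dispatch per
# trajectory instead of one (or more) per leapfrog step.
@jax.jit
def run_trajectory(x, p, step_size, n_steps):
    def cond(state):
        i, _, _ = state
        return i < n_steps

    def body(state):
        i, x, p = state
        p = p - 0.5 * step_size * jax.grad(neg_logp)(x)  # half-step momentum
        x = x + step_size * p                            # full-step position
        p = p - 0.5 * step_size * jax.grad(neg_logp)(x)  # half-step momentum
        return i + 1, x, p

    _, x, p = lax.while_loop(cond, body, (0, x, p))
    return x, p

x, p = run_trajectory(0.1, 0.0, 0.01, 10)
```

The same structure extends to the data-dependent stopping condition in NUTS tree building, which is exactly what a Python-level recursion cannot express inside a single jitted computation.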