I am trying to implement something similar to deflated power iteration as described in https://github.com/noahgolmant/pytorch-hessian-eigenthings/blob/master/hessian_eigenthings/power_iter.py.
I need to use an operator within a loop. How can I pass a function as an argument to another function while still being able to jit it and run it efficiently?
You have a couple of options:

- Don't `jit` the higher level function directly. Inside, `jit` a function that _calls_ the higher level function.
- Wrap the passed function in `jax.tree_util.Partial()`, which can be passed as an argument into a `jit` compiled function. But beware, this will break if your function is a closure (depending on arguments not defined in the function).

You can also use `static_argnums` for this (see the `jit` docstring):
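```python
from functools import partial
from jax import jit

# `f` (argument 0) is static: it can be a Python callable rather than a JAX
# array, and each new value of it triggers a recompilation.
@partial(jit, static_argnums=(0,))
def app(f, x):
    return f(x)

print(app(lambda x: 2 * x, 3))
```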
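A "static argument" means that (1) the argument can be an arbitrary Python object, e.g. a callable, rather than a JAX array, and (2) recompilation is triggered for every new value of the argument (based on `__eq__`/`__hash__` if the object is hashable, or object identity if it's not). Note that current JAX versions require static arguments to be hashable and raise an error otherwise; plain functions and lambdas hash by identity, so they still work.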
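To see that caching behaviour, here is a small sketch. The `trace_count` counter and the `double` helper are illustrative additions, not part of the original comment. The Python body of a jitted function only runs while JAX traces it, so a side-effecting counter shows when a new static value forces a retrace:

```python
from functools import partial

import jax
import jax.numpy as jnp

# Illustrative counter: the Python body of a jitted function runs only while
# JAX traces it, so this counts traces rather than calls.
trace_count = 0


@partial(jax.jit, static_argnums=(0,))
def app(f, x):
    global trace_count
    trace_count += 1
    return f(x)


def double(x):
    return 2 * x


x = jnp.float32(3.0)

# Reusing the same function object: traced once, later calls hit the cache.
for _ in range(3):
    app(double, x)
traces_after_reuse = trace_count  # 1

# Three distinct lambdas (kept alive in a list) are three different static
# values, so each one triggers its own trace and compilation.
fresh_fns = [lambda v: 2 * v for _ in range(3)]
for fn in fresh_fns:
    app(fn, x)
traces_after_lambdas = trace_count  # 4
```

Defining the operator once and reusing the same object is what makes `static_argnums` cache-friendly; creating a new lambda inside a loop defeats the cache.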
This is pretty similar to using a lexical closure together with the first option that @shoyer pointed out, but it can lead to more compilation cache hits.
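For comparison, a sketch of the closure-based first option: the higher-level function is left un-jitted, and a closure that calls it is jitted instead. `apply_n_times` and `matvec` are hypothetical names chosen for illustration:

```python
import jax
import jax.numpy as jnp


def apply_n_times(f, x, n):
    # Higher-level function that takes another function as an argument.
    # It is deliberately NOT decorated with @jax.jit.
    return jax.lax.fori_loop(0, n, lambda _, v: f(v), x)


def matvec(v):
    # Hypothetical linear operator used only for illustration.
    return jnp.array([[2.0, 0.0], [0.0, 0.5]]) @ v


# Jit a closure that *calls* the higher-level function; `matvec` and `n`
# are captured lexically, so only `x` is a traced argument.
run = jax.jit(lambda x: apply_n_times(matvec, x, 3))

result = run(jnp.ones(2))
```

Reusing `run` hits the compilation cache, but every new `jax.jit(lambda ...)` wrapper is a new function and is compiled separately, even for the same operator. That is the trade-off behind the remark that `static_argnums` can give more cache hits.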
@kunc, I hope that, together with the options above, this answers your question! Please reopen the issue (or open a new one) if it doesn't.
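Applied to the original question, the `jax.tree_util.Partial` option might look like the sketch below. It is not a port of the linked `power_iter.py`; it assumes a symmetric operator with a unique largest-magnitude eigenvalue at each step, and `matvec`, `deflated_matvec`, `power_iteration` and `top_k_eigenpairs` are hypothetical names. The wrapped functions are module-level rather than closures, in line with the caveat above, and the arrays travel as `Partial` arguments so they are traced instead of baked in as constants:

```python
from functools import partial  # functools.partial, used for the jit decorator

import jax
import jax.numpy as jnp
from jax.tree_util import Partial  # pytree-compatible partial application


def matvec(matrix, v):
    # Hypothetical operator: multiplication by a fixed symmetric matrix.
    # In practice this could be a Hessian-vector product.
    return matrix @ v


def deflated_matvec(base_op, eigvals, eigvecs, v):
    # A'(v) = A(v) - sum_i lambda_i * u_i * <u_i, v>.
    # Unused slots hold lambda_i = 0 and u_i = 0, so they contribute nothing.
    return base_op(v) - eigvecs.T @ (eigvals * (eigvecs @ v))


@partial(jax.jit, static_argnums=(2,))
def power_iteration(op, v0, num_iters):
    # `op` is a Partial (a pytree), so it is an ordinary traced argument;
    # the function it wraps is part of the pytree structure.
    def body(_, v):
        w = op(v)
        return w / jnp.linalg.norm(w)

    v = jax.lax.fori_loop(0, num_iters, body, v0 / jnp.linalg.norm(v0))
    return jnp.vdot(v, op(v)), v  # Rayleigh quotient of the unit vector v


def top_k_eigenpairs(base_op, n, k, num_iters=200, seed=0):
    # Fixed-size buffers keep the pytree shapes constant across the outer
    # loop, so `power_iteration` does not need to recompile for each i.
    eigvals = jnp.zeros((k,))
    eigvecs = jnp.zeros((k, n))
    key = jax.random.PRNGKey(seed)
    for i in range(k):
        key, subkey = jax.random.split(key)
        v0 = jax.random.normal(subkey, (n,))
        op = Partial(deflated_matvec, base_op, eigvals, eigvecs)
        lam, v = power_iteration(op, v0, num_iters)
        eigvals = eigvals.at[i].set(lam)
        eigvecs = eigvecs.at[i].set(v)
    return eigvals, eigvecs


A = jnp.diag(jnp.array([5.0, 3.0, 1.0, 0.5]))
vals, vecs = top_k_eigenpairs(Partial(matvec, A), n=4, k=2)
```

Because the deflation buffers keep the same shape and the wrapped functions are the same objects on every outer iteration, the `Partial` pytree has the same structure each time, so `power_iteration` should only need to compile once.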
It would be nice to document these usage tips.