Jax: Any way to avoid recompiling of jit'ed sub-functions?

Created on 14 Oct 2020 · 13 Comments · Source: google/jax

Suppose I have two functions

from jax import jit

def main_function(x):
    return x + 1

main_function_jit = jit(main_function)

def wrapper(x):
    return main_function_jit(x)

Theoretically speaking, is there any way to avoid recompilation of main_function_jit when I call jit(wrapper)?

In my work I have a big function that can be broken down into sub-functions, and some of these sub-functions can be reused in other big functions. The problem I encounter is that compiling these sub-functions ahead of time doesn't speed up compilation of the big function. Any ideas?

question

Most helpful comment

I believe these lines are where you can find JAX's backend compilation cache:
https://github.com/google/jax/blob/32010968992ff88c9c065ff1fa5ba6cbbfd21641/jax/interpreters/xla.py#L245-L247

It's stored in an LRU cache on xla_primitive_callable, not on the function objects themselves. Storing it on the function objects might be feasible, but memory management for caching compilations is already a little tricky...

All 13 comments

My understanding is that this isn't possible with JAX/XLA today. I agree it would be really nice to have.

In theory it would be possible if we created an XLA "CustomCall" for the inner jit decorated function (or with some sort of lower-level change in XLA itself).

Arguably this is the main reason why use-cases like lbfgs_optimize(odeint(fun, ...)) (https://github.com/google/jax/issues/3847) are so slow. XLA is recompiling inner functions many redundant times.
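
As a rough illustration of that redundancy, here is a hypothetical sketch (sub, big_a, and big_b are made-up names; the print call is just a tracing marker, and the claim is the thread's: the inner function's computation is compiled anew as part of each enclosing jit):

from jax import jit
import jax.numpy as jnp

@jit
def sub(x):
  print('tracing sub!')  # fires whenever JAX traces sub's Python code
  return x + 1

@jit
def big_a(x):
  return sub(x) * 2  # sub is compiled as part of big_a's XLA computation

@jit
def big_b(x):
  return sub(x) * 3  # sub is compiled again, as part of big_b's computation

x = jnp.ones(3)
big_a(x)
big_b(x)  # the work done compiling sub for big_a is not reused here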

@shoyer Got it. Another related question about compilation: I have a JIT'ed function for which I specify static_argnums. The static arguments are really just constant arrays. Yet I find that the check for whether a recompilation is needed is extremely strict: even if I create the exact same array (same in terms of its memory content), it still triggers a recompilation. Is there a way to work around this, perhaps by telling JAX that there's no need to check a certain static argument because the values in the array are guaranteed to be the same?

@gnool that's definitely a foot-gun that we want to revise. See #2813 and #3712.

Basically, for array types like numpy.ndarrays or JAX's DeviceArrays, jax.jit will recompile whenever the _object identity_ of a static_argnums argument is new. Because array types are not hashable, static_argnums currently falls back silently to object identity for arrays (and other unhashable objects). So using array types with static_argnums is a recipe for recompilation and slowness. In the near future, we plan to make static_argnums + unhashable type an error.

Example:

In [1]: import jax.numpy as jnp

In [2]: x = jnp.array([1., 2., 3.])
/usr/local/google/home/mattjj/packages/jax/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')

In [3]: def f(x):
   ...:     print('re-tracing!')
   ...:     return x ** 2
   ...:

In [4]: from jax import jit

In [5]: jit_f = jit(f, static_argnums=(0,))

In [6]: jit_f(x)
re-tracing!
Out[6]: DeviceArray([1., 4., 9.], dtype=float32)

In [7]: x = jnp.array([1., 2., 3.])

In [8]: jit_f(x)
re-tracing!
Out[8]: DeviceArray([1., 4., 9.], dtype=float32)

@mattjj If I understand your explanation correctly, there is no way to make static_argnums + unhashable types (arrays) work, right? Which is why it should at least raise an error. Pardon my poor understanding of how JIT works; I have to admit I'm still a little puzzled as to why recompilation is needed when the array content (and all its metadata except its identity and memory address) is exactly the same. I'm guessing there's a reason why a static argument has to be hashable, too.

My "workaround" now would be to avoid declaring the arrays as static then. Although the subsequent function call is still fast, it is noticeably slower (perhaps by an order of magnitude) than declaring the arrays as static (and avoiding the recompilation).

why recompilation is needed when the array content (and all its metadata except its identity and memory address) is exactly the same

The issue is that we'd have to check that the array content is the same. That is, we could define __eq__ and __hash__ on the DeviceArray class, but it'd be expensive: on every dispatch of the jitted function we'd have to compute a hash of all the static array arguments and possibly also compare them all elementwise to the cache key entries.

Plus, if you want that behavior, you can simulate it yourself without needing jax.jit to change at all. Just wrap the array objects you want to be static and cached on value in an instance of something like this:

class HashableArrayWrapper:
  def __init__(self, val):
    self.val = val
  def __hash__(self):
    return some_hash_function(self.val)  # maybe implement this in jax.numpy to save on transfers
  def __eq__(self, other):
    return isinstance(other, HashableArrayWrapper) and jnp.all(jnp.equal(self.val, other.val))

You could write a wrapper on jit like this:

def gnool_jit(fun, static_array_argnums=()):
  @jit  # EDIT: forgot to use static_argnums here! see comment below
  def callee(*args):
    args = list(args)
    for i in static_array_argnums:
      args[i] = args[i].val
    return fun(*args)

  def caller(*args):
    args = list(args)
    for i in static_array_argnums:
      args[i] = HashableArrayWrapper(args[i])
    return callee(*args)

  return caller

WDYT?

@mattjj If implementing my own __hash__ and __eq__ helps bypass the current object-identity check, that's definitely something I can explore. However, I'm running into the error below; any idea?

jax.traceback_util.FilteredStackTrace: TypeError: Argument '<__main__.HashableArrayWrapper object at 0x7fb39bd44ad0>' of type <class '__main__.HashableArrayWrapper'> is not a valid JAX type

Sorry, I forgot to use static_argnums in my example code above. Oops! Here's a working example:

from functools import partial

from jax import jit
import jax.numpy as jnp

def some_hash_function(x):
  return int(jnp.sum(x))

class HashableArrayWrapper:
  def __init__(self, val):
    self.val = val
  def __hash__(self):
    return some_hash_function(self.val)
  def __eq__(self, other):
    return (isinstance(other, HashableArrayWrapper) and
            jnp.all(jnp.equal(self.val, other.val)))

def gnool_jit(fun, static_array_argnums=()):
  @partial(jit, static_argnums=static_array_argnums)
  def callee(*args):
    args = list(args)
    for i in static_array_argnums:
      args[i] = args[i].val
    return fun(*args)

  def caller(*args):
    args = list(args)
    for i in static_array_argnums:
      args[i] = HashableArrayWrapper(args[i])
    return callee(*args)

  return caller


###


@partial(gnool_jit, static_array_argnums=(0,))
def f(x):
  print('re-tracing!')
  return x ** 2


x = jnp.array([1., 2., 3.])
f(x)
f(x)

x = jnp.array([1., 2., 3.])
f(x)
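
If the wrapper works as intended, 're-tracing!' should print only on the first call: the second call reuses the cached computation, and the third call, with a newly created but value-equal array, also hits the cache through HashableArrayWrapper's value-based __hash__ and __eq__.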

All we're doing here is making a hashable type, i.e. one with __hash__ and __eq__ methods that implement whatever behavior you want. The object-identity behavior we were talking about with static_argnums only kicks in when the type isn't hashable.

@mattjj Thanks for the super quick reply, it works now! For my own benefit: in our discussion above we touched on things like the cache. For a static argument, where exactly is the cache for the array's data stored? In the past few days I've been reading what others here have discussed regarding pickling a JIT-compiled function (to save compilation time), and so far the answers indicate that this is unsupported by JAX. Out of curiosity I used cloudpickle to pickle a JIT-compiled function (it pickled without error), dumped it, loaded it back, and unsurprisingly it behaved like an uncompiled function (i.e. it needed to warm up again). Is this because the cache was not stored during pickling, or for some other, more complicated reason?

I believe these lines are where you can find JAX's backend compilation cache:
https://github.com/google/jax/blob/32010968992ff88c9c065ff1fa5ba6cbbfd21641/jax/interpreters/xla.py#L245-L247

It's stored in an LRU cache on xla_primitive_callable, not on the function objects themselves. Storing it on the function objects might be feasible, but memory management for caching compilations is already a little tricky...

Glad to hear it works!

For a static argument, where exactly is the cache for the array's data stored?

The compilation cache is really just this one memoization decorator, defined here and applied here. Argument values that correspond to static_argnums positions are actually part of the wrapper that makes up a WrappedFun, in particular they're part of the fun argument on this line. (The compiled executable is part of the cache value, rather than the key.)
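
In outline, the structure looks something like this (a simplified, hypothetical sketch based on the spelunking below, not the actual implementation; trace_and_compile is a made-up stand-in):

import weakref

fun_caches = weakref.WeakKeyDictionary()  # underlying Python function -> per-function cache

def memoized_fun(fun, *args):
  # fun is a WrappedFun; fun.f is the underlying Python callable. Values
  # bound via static_argnums live inside fun.transforms, so they enter the
  # cache key through `fun`. The compiled executable ends up in the cached
  # value, not the key.
  cache = fun_caches.setdefault(fun.f, {})
  key = (fun.transforms, fun.params, args)
  if key not in cache:
    cache[key] = trace_and_compile(fun, *args)  # hypothetical stand-in
  return cache[key]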

Here's a little spelunking to show where the array lives in the above example:

In [1]: run gnool.py
/usr/local/google/home/mattjj/packages/jax/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
> /usr/local/google/home/mattjj/packages/jax/jax/linear_util.py(241)memoized_fun()
-> cache = fun_caches.setdefault(fun.f, {})
(Pdb) l
236       fun_caches: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
237       thread_local: threading.local = _CacheLocalContext()
238
239       def memoized_fun(fun: WrappedFun, *args):
240         breakpoint()
241  ->     cache = fun_caches.setdefault(fun.f, {})
242         key = (fun.transforms, fun.params, args)
243         result = cache.get(key, None)
244         if result is not None:
245           ans, stores = result
246           fun.populate_stores(stores)
(Pdb) fun
Wrapped function:
0   : process_env_traces   (xla_call, 0, (('device', None), ('backend', None), ('name', 'callee'), ('donated_invars', ())))
1   : flatten_fun   (PyTreeDef(tuple, [PyTreeDef(tuple, []),PyTreeDef(dict[[]], [])]),)
2   : _argnums_partial   ((), (<jax.util.Hashable object at 0x7f96c89bc690>,))
Core: callee

(Pdb) fun.transforms[2][1][1][0].val.val
DeviceArray([1., 2., 3.], dtype=float32)

However, those implementation details are all subject to change, even in the very near future, so I wouldn't build anything against them.

@shoyer that's the line for the "op-by-op" cache, which is indeed one compilation cache, though the cache for jit is lower down in the same file.

@mattjj @shoyer Sounds to me like I'm playing with fire, then. I'll keep my hands off this and perhaps just stick to pre-compilation at the beginning of the program. Really looking forward to the day when JAX lets us store pre-compiled functions. Thanks again for all the support and for building this awesome tool!

Thanks for the words of encouragement! We hear you on the need for pre-compiled executables. We should track that on #476.
