JAX: Does lax.cond short-circuit?

Created on 15 May 2020 · 10 Comments · Source: google/jax

Hello! I have a function f that wraps two functions: f_1, which is very expensive, and f_2, which is not (they return arrays of the same shape). If one of the arguments to f is false, we do not need the expensive function. Ultimately, I wrap this inside a jitted function, so I must use lax.cond to choose between f_1 and f_2. Does this buy me anything, or do both sides of the conditional have to be executed because of the way JAX works? Thanks!

contributions welcome documentation question

All 10 comments

Both sides of the conditional are traced, meaning both branch functions are evaluated with tracer objects that don't do any computation in order to discover the operations to be compiled with jit. This should be fast even for the expensive function, since no computation is performed. When the final jitted function is executed with real values, only one branch will be run.
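
A minimal sketch of the distinction between tracing and execution, assuming a recent JAX release in which the signature is lax.cond(pred, true_fun, false_fun, *operands) rather than the five-argument form used elsewhere in this thread. The names expensive, cheap, and traced are hypothetical; the Python list only records what happens at trace time.

```python
import jax
import jax.numpy as jnp
from jax import lax

traced = []  # filled by ordinary Python side effects, i.e. only while tracing

def expensive(x):
    traced.append("expensive")  # runs when the branch is traced, not when it executes
    return jnp.sin(x) ** 2

def cheap(x):
    traced.append("cheap")
    return jnp.cos(x)

@jax.jit
def f(flag, x):
    # Both branch functions are traced once to build the compiled program.
    return lax.cond(flag, expensive, cheap, x)

y = f(True, jnp.ones(3))
print(traced)  # both branch names show up after a single call
```

Calling f again with a different flag but the same argument types should reuse the compiled program, so nothing new is appended to traced.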

got it, thanks!

One detail to add on: only the operations in each branch that have a data dependence on the explicit branch operands will be delayed; operations with no data dependence on the operands are executed at trace time when not using a jit, and unconditionally when using a jit.

Here's an example:

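from jax import jit, lax
import jax.numpy as np  # assumed: np refers to jax.numpy in this snippet

@jit
def f(x):
  # x is closed over by both lambdas rather than passed as the operand
  return lax.cond(x > 0,
                  (), lambda _: np.sin(x),
                  (), lambda _: np.cos(x))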

On the current master branch, both np.sin(x) and np.cos(x) will be evaluated on each evaluation of f(x). Another way to put it is that they'll be hoisted out of the cond entirely.

To ensure only one side is executed per application of f, we'd need to rewrite it as

@jit
def f(x):
  # x is passed as the explicit operand of each branch
  return lax.cond(x > 0,
                  x, lambda x: np.sin(x),
                  x, lambda x: np.cos(x))

This is a weird quirk of our tracing implementation, and we're working on revising it. Hoping to land a fix in the next couple weeks!
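
For readers on a newer JAX release, here is a sketch of the same explicit-operand pattern written against the signature lax.cond(pred, true_fun, false_fun, *operands). That signature is an assumption about the installed version, and the snippet does not show whether the hoisting described above still occurs.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def f(x):
    # x is handed to lax.cond as an operand, so each branch receives it
    # as an argument instead of closing over it.
    return lax.cond(x > 0,
                    lambda v: jnp.sin(v),
                    lambda v: jnp.cos(v),
                    x)

pos = f(1.0)   # takes the sin branch
neg = f(-1.0)  # takes the cos branch
```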

interesting, thanks @mattjj ! One more q: is there a way to determine if certain jax code was executed? Would be very useful for debugging!

I think you mean executed as in evaluated, like to ensure that only one side of the cond was taken rather than both. (If by "executed" you mean "traced" then you can use Python print function calls.)

Since XLA HLO doesn't have errors, I think the only way to do it without side-effects is via non-termination, e.g. by putting an infinite lax.while_loop in one of the branches of the lax.cond.

Otherwise you'd need to use a side-effect. Two readily-available side-effects are time and heat (perhaps those are the same thing...); that is, if f_1 is very expensive perhaps you can decide whether it was executed based on how much time the computation takes, or how much heat your processor generates!

More seriously, there are side-effects in XLA, but we have only exposed them in experimental APIs (infeed and outfeed). I don't necessarily recommend using them right now, but the host callback outfeed mechanism is the _perfect_ API for this (cc @gnecula).
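
As a rough illustration of the side-effect approach, the sketch below uses jax.debug.callback, which newer JAX releases provide; it is not the experimental outfeed API mentioned above, and its availability depends on the installed version. The helper record and the list executed are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax import lax

executed = []  # appended to on the host at run time, not at trace time

def record(name):
    # Hypothetical helper: builds a host callback that logs which branch ran.
    def _cb(value):
        executed.append(name)
    return _cb

def true_branch(x):
    jax.debug.callback(record("sin"), x)
    return jnp.sin(x)

def false_branch(x):
    jax.debug.callback(record("cos"), x)
    return jnp.cos(x)

@jax.jit
def f(x):
    return lax.cond(x > 0, true_branch, false_branch, x)

f(1.0).block_until_ready()
jax.effects_barrier()  # wait for pending host callbacks to finish
print(executed)
```

Under jit with a scalar predicate, the list should name only the branch that was taken. Transformations such as vmap can change how the conditional is lowered, so treat this as a debugging aid and not a guarantee.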

Instead of verifying what was executed, it might be good enough to just look at the XLA HLO programs we send to the compiler, then trust in the XLA HLO operational semantics around conditionals. If that works, I can tell you some ways to print the XLA HLO being generated. Then at least you could see the funny hoisting behavior I alluded to, and also see when it's fixed. Would that be useful?
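
One possible way to inspect what gets sent to the compiler, sketched with jax.make_jaxpr and the lower(...).as_text() method of jitted functions. Both are assumed to be available in the installed JAX version. The jaxpr is JAX's own intermediate representation, one level above XLA HLO, and the exact text of either printout varies between releases.

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
    return lax.cond(x > 0, lambda v: jnp.sin(v), lambda v: jnp.cos(v), x)

# JAX-level view: the conditional appears as a single cond primitive
# whose branches contain sin and cos respectively.
jaxpr = jax.make_jaxpr(f)(1.0)
print(jaxpr)

# Compiler-level view: the textual form of the lowered module.
lowered_text = jax.jit(f).lower(1.0).as_text()
print(lowered_text)
```

If an operation shows up outside the conditional in either printout, it has been hoisted and will run regardless of which branch is taken.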

I think this should be added to the FAQ, or documented explicitly somewhere

@joaogui1 I'm reading these threads to supplement the documentation. The comments in these issues are filled with good insight.

@mattjj Thank you for the always-so-helpful response! This is more than enough to move forward.

I am tempted to close this issue. I do not quite understand what needs to be documented. Is it the fact that the only way to tell whether some code was executed is to use id_print? Or is it the hoisting behavior? (The latter is going to change soon.)

In general, XLA reserves the right to execute (or not execute) code as long as one cannot tell by the result of the computation.

I am closing for now; please re-open if you feel it needs to stay open.

I think the jax.lax.cond API has changed since this issue was first opened and I'm not sure @mattjj's comments apply in the same way. For example, if I do

import jax

def f(x):
  return jax.lax.cond(x > 0, lambda x: x**2, lambda x: jax.lax.while_loop(lambda x: True, lambda _: _, 0), x)

then doing f(2) will run the infinite loop. How can I avoid that?
