The README says:

> If you want compiled control flow, use structured control flow primitives like `lax.cond` and `lax.while`.
But `lax.cond` does not actually exist, and `lax.while` only exists in the form of `lax._while_loop`.
Is `lax.cond` already implemented somewhere, just not on the master branch? I don't need gradient support for `lax.cond` (which is tracked by PR #83); a jit-compilable `cond` alone would be a huge gain for me. (I'm currently looking into whether I could contribute this if it doesn't exist yet.)
`lax.while` will probably never exist, because `while` is a reserved keyword in Python. It should probably be renamed to `lax.while_loop`, which could likely be done right away by renaming `lax._while_loop`. Or is there a problem with the `_while_loop` implementation? `fori_loop` and `foreach_loop` are not prefixed with an underscore and are built on `_while_loop`, so it should be fully functional.
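To illustrate the point that a counted loop is just a thin layer over a while loop, here is a minimal sketch of a `fori_loop` built only on a while primitive. It assumes a later JAX release, where the public name is `jax.lax.while_loop(cond_fun, body_fun, init_val)`; the helpers `fori_loop_via_while` and `sum_below` are hypothetical names for illustration, not JAX's own implementation.

```python
import jax
import jax.numpy as jnp
from jax import lax


def fori_loop_via_while(lower, upper, body_fun, init_val):
    """Hypothetical helper: run body_fun(i, val) for i in [lower, upper) via lax.while_loop."""

    def cond_fun(carry):
        i, _ = carry
        return i < upper  # keep looping while the counter is below `upper`

    def body(carry):
        i, val = carry
        return i + 1, body_fun(i, val)  # advance the counter and update the state

    _, result = lax.while_loop(cond_fun, body, (lower, init_val))
    return result


@jax.jit
def sum_below(n):
    # Sum 0 + 1 + ... + (n - 1); the loop bound `n` is a traced value.
    return fori_loop_via_while(0, n, lambda i, acc: acc + i, 0)


print(sum_below(10))  # 45
```

The loop state is a tuple `(counter, value)`, which is essentially the carry a counted loop needs; the upper bound is simply closed over by the condition function.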
In the meantime (before `lax.cond` is ready), you can potentially use "predication", i.e. something like `def cond(b, x, y): return b * x + (1 - b) * y`, if that doesn't add too much overhead.
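A minimal sketch of the predication trick under `jax.jit`, assuming the predicate is a boolean scalar or a boolean array broadcastable against `x` and `y`, and that `x` and `y` share a float dtype. The cast of `b` to the value dtype is an added safeguard, not part of the one-liner above.

```python
import jax
import jax.numpy as jnp


@jax.jit
def cond(b, x, y):
    # Predication: evaluate BOTH branches, then blend them with a 0/1 mask.
    # Cast the boolean predicate to the value dtype so the arithmetic is well defined.
    b = jnp.asarray(b).astype(x.dtype)
    return b * x + (1 - b) * y


x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([10.0, 20.0, 30.0])
print(cond(True, x, y))                            # [1. 2. 3.]
print(cond(jnp.array([True, False, True]), x, y))  # [ 1. 20.  3.]
```

Two costs to keep in mind: both `x` and `y` are always computed, and a non-finite value in the unselected branch leaks into the result, since `0 * inf` is `nan` (e.g. `cond(True, jnp.array(1.0), jnp.array(jnp.inf))` gives `nan`).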
@jekbradbury `np.where(b, x, y)` is also possible (aka `lax.select` with different broadcasting semantics).
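A short sketch contrasting the two, assuming a JAX release where `jax.numpy.where` and `jax.lax.select(pred, on_true, on_false)` are available; `select_demo` is a hypothetical name. `jnp.where` broadcasts its arguments, while `lax.select` expects operands that already share a shape and dtype.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def select_demo(x):
    pred = x > 1.0  # boolean mask with the same shape as x
    # jnp.where broadcasts, so a scalar 0.0 works as the "false" operand.
    via_where = jnp.where(pred, x, 0.0)
    # lax.select is stricter: both operands must match pred's shape and share a dtype.
    via_select = lax.select(pred, x, jnp.zeros_like(x))
    return via_where, via_select


w, s = select_demo(jnp.array([0.0, 1.0, 2.0, 3.0]))
print(w)  # [0. 0. 2. 3.]
print(s)  # [0. 0. 2. 3.]
```

Unlike the multiply-and-add form, selection does not mix the unselected value into the result, so `jnp.where(True, 1.0, jnp.inf)` is `1.0` rather than `nan`. Both branches are still evaluated, though.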
Thanks @jekbradbury and @hawkinsp !
While we are proposing workarounds: a single-iteration while-loop should also work, but your solutions are even better :)
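One way the while-loop workaround might look, as a sketch assuming a JAX release with `jax.lax.while_loop`; `run_if` and `cond_via_while` are hypothetical helpers. A loop that runs zero or one times only gives a one-armed "if", so this sketch uses two loops to cover both branches.

```python
import jax
import jax.numpy as jnp
from jax import lax


def run_if(pred, fun, x):
    """Hypothetical helper: apply `fun` to `x` iff the scalar `pred` is true (0 or 1 iterations)."""

    def body(carry):
        _, val = carry
        # Clear the flag so the loop exits after this single iteration.
        return jnp.array(False), fun(val)

    _, out = lax.while_loop(lambda carry: carry[0], body, (jnp.asarray(pred), x))
    return out


def cond_via_while(pred, true_fun, false_fun, x):
    # Both branches must return values with the same shape and dtype as `x`,
    # because the result becomes part of the loop carry.
    pred = jnp.asarray(pred)
    x = run_if(pred, true_fun, x)
    return run_if(jnp.logical_not(pred), false_fun, x)


f = jax.jit(lambda p, x: cond_via_while(p, lambda v: v * 2.0, lambda v: v - 1.0, x))
print(f(True, jnp.array(3.0)))   # 6.0
print(f(False, jnp.array(3.0)))  # 2.0
```

In contrast to predication or `where`, only the taken branch's body actually executes at runtime. The trade-off is that reverse-mode differentiation through `lax.while_loop` is not supported.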
Early versions of `lax.cond` (circa early 2017) lowered into while loops (though I think we needed two in general; I can't remember why), but they were totally broken from a tracing-composability perspective, which is why we don't have them now.
The main reason we haven't exposed a `lax.cond` yet is essentially the same as why #331 and #207 are outstanding issues, namely that we want to handle closures and arbitrary composability correctly. @dougalm designed the core system to handle these issues, and it actually enables two ways of handling higher-order functions like `lax.cond` and `lax.while`, which we can call "the hard way" and "the easy way". We recently decided that the practical benefits of the hard way over the easy way are pretty minuscule (though academically interesting), so @dougalm started going "the easy way" in #334.
#334 doesn't add `lax.cond`, but it paves the way for doing it (along with a differentiable `lax.map` and `lax.scan`). We'll get `lax.cond` in #415, probably with some limitations at first (e.g. no reverse-mode autodiff for now, as per the OP's request).
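For readers arriving later: in current JAX releases `lax.cond` is public, with the signature `lax.cond(pred, true_fun, false_fun, *operands)`. This may differ from the API first introduced in #415, so check the docs for your version. The sketch below (with the hypothetical name `piecewise`) has both branches close over a traced value, the kind of closure the comment above says must be handled correctly.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def piecewise(pred, x, scale):
    # Both branches close over the traced value `scale`, not just their operand.
    return lax.cond(pred,
                    lambda v: v * scale,   # branch taken when pred is True
                    lambda v: v - scale,   # branch taken when pred is False
                    x)


print(piecewise(True, jnp.array(3.0), jnp.array(2.0)))   # 6.0
print(piecewise(False, jnp.array(3.0), jnp.array(2.0)))  # 1.0
```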