Dear jax team,
I wonder whether there are any benefits to using `jax.lax.scan` other than reduced compilation time. Or are there even drawbacks in situations where a simple loop would be manageable?
I am looping over the dimensions of an input array, of which there are never more than 3. I haven't gotten around to porting the code to `scan` yet and am wondering whether it's worth it or even detrimental.
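For concreteness, here is a minimal sketch of the kind of loop described above. The function `sum_sq_diffs` is hypothetical and only stands in for "some per-axis computation". Under `jax.jit`, `x.ndim` is a static Python int, so the `for` loop runs at trace time and is unrolled into one copy of the body per axis.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-axis computation: sum of squared forward differences
# along every axis of the input.
@jax.jit
def sum_sq_diffs(x):
    total = 0.0
    # x.ndim is known at trace time, so this Python loop is unrolled:
    # the compiled program contains one copy of the body per axis.
    for axis in range(x.ndim):
        total = total + jnp.sum(jnp.diff(x, axis=axis) ** 2)
    return total

x = jnp.arange(12.0).reshape(3, 4)
print(sum_sq_diffs(x))  # 128 (axis 0) + 9 (axis 1) = 137
```

Note that a loop like this one would not map directly onto `lax.scan` anyway: `scan` traces its body once with fixed carry shapes, while here `axis` is a static argument to `jnp.diff` that changes every iteration and produces differently shaped intermediates.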
If that information is already available somewhere, sorry!
The main reason to use `scan` rather than an unrolled loop is to avoid compilation time (and tracing time, autodiff time, code size, etc.) that's linear (or worse) in the loop's trip count. So it isn't so much about lower compilation time as about _asymptotically lower_ compilation time 🙂.
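One way to see the "linear in the trip count" point is to count the equations in the traced program (the jaxpr) for a Python loop versus `lax.scan`. This is a sketch with a hypothetical loop body `step`; the exact counts depend on the JAX version, but the unrolled program should grow with `n` while the `scan` version stays the same size because `n` is only a parameter of the single loop primitive.

```python
import jax
import jax.numpy as jnp
from jax import lax

def step(c):
    # Hypothetical loop body: a couple of cheap elementwise ops.
    return jnp.sin(c) + 1.0

def unrolled(c, n):
    # Python loop: the body is traced n times, so the program grows with n.
    for _ in range(n):
        c = step(c)
    return c

def scanned(c, n):
    # lax.scan traces the body once and emits a single loop primitive.
    def body(carry, _):
        return step(carry), None
    c, _ = lax.scan(body, c, xs=None, length=n)
    return c

def num_eqns(fn, n):
    # Number of top-level equations in the traced program for trip count n.
    return len(jax.make_jaxpr(lambda c: fn(c, n))(1.0).jaxpr.eqns)

for n in (3, 30, 300):
    print(n, num_eqns(unrolled, n), num_eqns(scanned, n))
```

Tracing, autodiff and XLA compilation all work on this program, which is why their cost tracks its size.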
If you're looping over only three things, scan almost certainly won't help you (and often comes with a runtime penalty over unrolling, especially on GPU).
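If you want to check the runtime side on your own hardware, a rough sketch like the following compares a 3-iteration Python loop with the equivalent `scan`. The body `step`, the array size and the iteration counts are arbitrary assumptions; results will vary by backend, so treat the printed numbers as indicative only.

```python
import timeit

import jax
import jax.numpy as jnp
from jax import lax

N = 3  # small trip count, like looping over at most 3 axes

def step(c):
    # Hypothetical loop body.
    return c + 0.5 * jnp.tanh(c)

@jax.jit
def unrolled(c):
    for _ in range(N):
        c = step(c)
    return c

@jax.jit
def scanned(c):
    c, _ = lax.scan(lambda carry, _: (step(carry), None), c, xs=None, length=N)
    return c

x = jnp.ones((512, 512))
# Both versions compute the same thing.
assert jnp.allclose(unrolled(x), scanned(x))

for name, f in (("unrolled", unrolled), ("scan", scanned)):
    f(x).block_until_ready()  # compile outside the timed region
    t = timeit.timeit(lambda: f(x).block_until_ready(), number=100)
    print(f"{name}: {t / 100 * 1e6:.1f} us per call")
```

`block_until_ready()` is needed because JAX dispatches asynchronously; without it the timing would mostly measure dispatch rather than execution.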
Thanks @jekbradbury