jax.lax.scan benefits apart from compilation time

Created on 19 Dec 2019 · 2 Comments · Source: google/jax

Dear jax team,

Are there any benefits to using jax.lax.scan other than reduced compilation time? Or are there even drawbacks in situations where a simple loop would be manageable?

I am looping over the dimensions of an input array, of which there are never more than 3. I have not gotten around to porting the code to scan yet and am wondering whether it's worth it, or even detrimental.

If that information is already available somewhere, sorry!


All 2 comments

The main reason to use scan rather than an unrolled loop is to avoid compilation time (and tracing time, autodiff time, code size, etc.) that's linear (or worse) in the loop's trip count. So it isn't so much about lower compilation time as about _asymptotically lower_ compilation time 🙂.
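A minimal sketch of the trip-count point, assuming a current JAX install; `step`, `unrolled`, `scanned`, and `num_eqns` are hypothetical names used only for illustration. It traces the same update once as a Python loop and once as `jax.lax.scan`, using `jax.make_jaxpr` (which shows the traced program that `jax.jit` would go on to compile), and counts the top-level equations in each.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(c, x):
    # One loop iteration: a small update of the scalar carry (hypothetical).
    return jnp.sin(c) * x + 1.0


def unrolled(c, xs):
    # Python loop: traced once per iteration, so every iteration's ops
    # are copied into the traced program.
    for i in range(xs.shape[0]):
        c = step(c, xs[i])
    return c


def scanned(c, xs):
    # lax.scan: the body is traced once and appears as a single scan op.
    def body(carry, x):
        return step(carry, x), None

    c, _ = lax.scan(body, c, xs)
    return c


def num_eqns(fn, n):
    # Count top-level equations in the jaxpr traced for a loop of length n.
    c0 = jnp.zeros(())  # float32 scalar carry
    xs = jnp.ones((n,))  # float32 inputs, one per iteration
    closed = jax.make_jaxpr(fn)(c0, xs)
    return len(closed.jaxpr.eqns)


for n in (3, 30, 300):
    print(n, num_eqns(unrolled, n), num_eqns(scanned, n))
```

The unrolled count should grow roughly in proportion to `n`, while the scan version stays at the same small count however long the loop is; the exact numbers depend on the JAX version.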

If you're looping over only three things, scan almost certainly won't help you (and often comes with a runtime penalty over unrolling, especially on GPU).
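For a loop with only a handful of iterations, such as one over the (at most three) dimensions of an input array, a plain Python loop inside `jax.jit` is the straightforward option. The sketch below assumes a hypothetical per-axis computation, `sum_of_axis_diffs`; it is not code from the question. Because `x.ndim` and `axis` are static Python ints, the loop is simply unrolled at trace time. A `lax.scan` body would instead receive its loop variable as a traced value, which cannot be passed as a static `axis` argument.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sum_of_axis_diffs(x):
    # Hypothetical example: for each dimension of x, take a forward
    # difference along that axis and accumulate the sum of squares.
    # x.ndim is known at trace time, so this unrolls into ndim copies.
    total = jnp.zeros((), dtype=x.dtype)
    for axis in range(x.ndim):
        d = jnp.diff(x, axis=axis)
        total = total + jnp.sum(d ** 2)
    return total


x = jnp.arange(24.0).reshape(2, 3, 4)
print(sum_of_axis_diffs(x))
```

If a longer loop does call for `scan`, recent JAX versions accept an `unroll` argument to `jax.lax.scan`, which unrolls that many iterations per loop step and can recover some of the runtime lost to loop overhead at the cost of a larger program.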

Thanks @jekbradbury
