Over the past year, I think there has been a bit of a transition in how JAX-flavored NumPy is used: from "JAX is a drop-in replacement for NumPy" to "JAX is a tool to use beside NumPy". This has manifested in many different ways (the switch from import jax.numpy as np to import jax.numpy as jnp, the deprecation of silent conversions to array within aggregates, some aspects of omnistaging, etc.), but I think we've now reached the point where using JAX effectively requires thinking about which operations should be staged (via jnp routines) and which should not be (via np routines).
I think it's important that we provide some entry point for developing the correct mental model for using JAX effectively. This should probably take the form of a new narrative doc, which would absorb some aspects of the existing Sharp Bits doc.
A quick brainstorm of key points we should cover:
- jnp for operations that you want to be staged/optimized, and np for "compile-time" operations
- jnp.prod(shape) vs. np.prod(shape). Why should you use the latter?
- x[idx] = y does not work. Instead, use x = x.at[idx].set(y) and related methods

I think this would be great.
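To make the last brainstorm item concrete, here is a small sketch of the difference between NumPy-style item assignment and JAX's functional .at[...] updates. The variable names are illustrative; set, add and multiply are existing .at[...] methods.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# JAX arrays are immutable, so NumPy-style item assignment raises a TypeError.
try:
    x[1] = 10.0
    item_assignment_failed = False
except TypeError:
    item_assignment_failed = True

# Functional updates via .at[...] return a new array and leave x unchanged.
y = x.at[1].set(10.0)      # y[1] == 10.0
z = y.at[1].add(5.0)       # z[1] == 15.0
w = z.at[1].multiply(2.0)  # w[1] == 30.0
```

Under jit, the compiler can often perform such an update in place when the original array is not used afterwards, so the functional style does not necessarily cost a copy.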
Do you have any preliminary pointers on when to use jnp vs np?
Briefly: use jnp when you want the calculation to be compiled / to be performed on the accelerator. Use np when you want the calculation to happen on the CPU / at compile time.
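A minimal sketch of the jnp.prod(shape) vs. np.prod(shape) point from the brainstorm: shapes are static at trace time, so computing with them in np gives a concrete value, while jnp stages the computation and yields a tracer that cannot be used as a shape. The function names are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def flatten_with_np(x):
    # x.shape is a tuple of Python ints known at trace time, so np.prod runs
    # on the host while tracing and yields a concrete integer.
    return x.reshape((np.prod(x.shape),))

@jax.jit
def flatten_with_jnp(x):
    # jnp operations are staged into the compiled computation, so under jit
    # this product is an abstract tracer rather than a concrete integer,
    # and reshape cannot use it as a shape.
    return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
flat = flatten_with_np(x)  # works: shape (6,)

try:
    flatten_with_jnp(x)
    jnp_version_failed = False
except TypeError:
    jnp_version_failed = True
```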
Another point to touch on is the idea of preferring fixed-size iterative algorithms over recursive ones. This is the main challenge in coding complicated algorithms with JAX, and I presume it's what trips up many people; I know I have had to redesign an algorithm several times to make it work with JAX. But the speed improvements when you do this are amazing.
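As an illustration (our choice of example, not taken from the comment above), here is bisection for square roots written two ways. The recursive version uses a data-dependent stopping rule and Python control flow, which works on concrete values but not on tracers under jit. The fixed-size version runs a set number of iterations with lax.fori_loop and replaces the Python if with jnp.where, so it can be jitted and works elementwise on arrays. The iteration count of 40 and the bracket [0, max(c, 1)] are assumptions for non-negative inputs.

```python
import jax
import jax.numpy as jnp
from jax import lax

def bisect_recursive(f, lo, hi, tol=1e-6):
    # Recursion with a data-dependent stopping rule: fine on Python floats,
    # but under jit `hi - lo < tol` would be a tracer and could not drive `if`.
    mid = (lo + hi) / 2
    if hi - lo < tol:
        return mid
    if f(lo) * f(mid) <= 0:
        return bisect_recursive(f, lo, mid, tol)
    return bisect_recursive(f, mid, hi, tol)

@jax.jit
def sqrt_bisect_fixed(c):
    # Fixed number of iterations: the loop structure no longer depends on data.
    def body(_, bounds):
        lo, hi = bounds
        mid = (lo + hi) / 2
        # Keep the half-interval where x*x - c changes sign.
        go_left = (lo * lo - c) * (mid * mid - c) <= 0
        return jnp.where(go_left, lo, mid), jnp.where(go_left, mid, hi)

    init = (jnp.zeros_like(c), jnp.maximum(c, 1.0))
    lo, hi = lax.fori_loop(0, 40, body, init)
    return (lo + hi) / 2

cs = jnp.array([2.0, 9.0, 0.25])
roots = sqrt_bisect_fixed(cs)
root_2 = bisect_recursive(lambda v: v * v - 2.0, 0.0, 2.0)
```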
Another point might be about reducing how many cond operations you use so that accelerators can work as effectively as possible, i.e. sometimes it's better to use a where and compute both branches of a switch rather than use a cond. Knowing when to do that typically requires profiling, but there are some good rules of thumb.
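A rough sketch of the two styles, with hypothetical function names. jnp.where evaluates both candidate values and selects per element, which maps onto a simple select with no control flow. lax.cond takes a scalar predicate and runs only the chosen branch, which can pay off when one branch is expensive. Which is faster in a given program is exactly the kind of question that needs profiling.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def leaky_relu_where(x):
    # Elementwise choice: both x and 0.01 * x are computed, then selected.
    return jnp.where(x > 0, x, 0.01 * x)

@jax.jit
def normalize_cond(x):
    # Scalar predicate on a whole-array quantity: only one branch executes.
    norm = jnp.linalg.norm(x)
    return lax.cond(norm > 1.0, lambda v: v / norm, lambda v: v, x)

a = leaky_relu_where(jnp.array([-2.0, 3.0]))
b = normalize_cond(jnp.array([3.0, 4.0]))  # norm 5 > 1, so divided by 5
c = normalize_cond(jnp.array([0.3, 0.4]))  # norm 0.5, returned unchanged
```

Note that when lax.cond is vmapped with a batched predicate, JAX converts it into a select, so both branches are evaluated anyway.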
It might also be interesting to discuss
A relevant user question: https://github.com/google/jax/issues/5280#issuecomment-752702613