JAX: Add a "How to think in JAX" doc

Created on 19 Oct 2020 · 5 comments · Source: google/jax

Over the past year, I think there has been a bit of a transition in the usage of JAX-flavored numpy, from "JAX is a drop-in replacement for numpy" to "JAX is a tool to use beside numpy". This has manifested in many different ways (the switch from `import jax.numpy as np` to `import jax.numpy as jnp`, the deprecation of silent conversions to array within aggregates, some aspects of omnistaging, etc.), but I think we've now gotten to the point where using JAX effectively requires thinking about which operations should be staged (via `jnp` routines) and which should not be (via `np` routines).

I think it's important that we provide some entry-point into developing the correct mental model for using JAX effectively. I think this probably should take the form of a new narrative doc, which would absorb some aspect of the existing Sharp Bits doc.

A quick brainstorm of key points we should cover:

  • JAX is not a drop-in replacement for numpy, but a numpy-like interface for staged computations.
  • Use `jnp` for operations that you want to be staged/optimized, and `np` for "compile-time" operations.
  • Common example: `jnp.prod(shape)` vs. `np.prod(shape)`. Why should you use the latter? (See the first sketch below.)
  • Key difference: JAX arrays are immutable, so `x[idx] = y` does not work. Instead use `x = x.at[idx].set(y)` and related methods. (See the second sketch below.)
  • JAX and dynamic shapes: they can be staged but not compiled.
  • Other points?
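
To make the `jnp.prod` vs. `np.prod` bullet concrete, here is a minimal sketch (the `flatten` helper is hypothetical, just for illustration):

```python
import numpy as np
import jax.numpy as jnp
from jax import jit

@jit
def flatten(x):
    # np.prod runs at trace time on the static shape tuple and returns a
    # plain integer, so it is valid as a shape argument for reshape.
    return x.reshape(np.prod(x.shape))
```

Under omnistaging, writing `jnp.prod(x.shape)` instead would stage the product into the traced computation; the result is a traced value, and traced values cannot be used as static shapes, so the reshape would fail.

And a quick sketch of the immutable-update point:

```python
import jax.numpy as jnp

x = jnp.zeros(5)
# x[0] = 1.0            # TypeError: JAX arrays are immutable
y = x.at[0].set(1.0)    # functional update: returns a new array
z = x.at[1:3].add(2.0)  # related methods: .add, .mul, and friends
```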
Label: documentation

All 5 comments

I think this would be great.

Do you have any preliminary pointers on when to use jnp vs np?

Briefly: use `jnp` when you want the calculation to be compiled / to be performed on the accelerator. Use `np` when you want the calculation to happen on the CPU / at compile time.
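
A small sketch of that rule of thumb (the `apply_window` function and the choice of a Hann window are just illustrative):

```python
import numpy as np
import jax.numpy as jnp
from jax import jit

@jit
def apply_window(x):
    # np.hanning runs once, on CPU, at trace time; the result is baked
    # into the compiled program as a constant.
    window = np.hanning(x.shape[0])
    # The multiply is staged: it is compiled and runs on the accelerator.
    return x * window
```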

Another point to touch on is the idea of fixed-size iterative algorithms over recursive ones. It's the main challenge in coding complicated algorithms with JAX. It's what hinders many people, I presume, and I know I have had to redesign an algorithm several times to make it work with JAX. But the results are amazing in terms of speed improvements when you do this.
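
For example, here is a sketch of the pattern, assuming a fixed-trip-count Newton iteration for square roots (`newton_sqrt` is illustrative, not from the issue):

```python
from jax import jit, lax

@jit
def newton_sqrt(a):
    # A fixed number of iterations instead of "recurse until converged":
    # the loop structure is static, so the whole function can be staged
    # and compiled.
    def body(i, x):
        return 0.5 * (x + a / x)
    return lax.fori_loop(0, 10, body, a)
```

Ten steps is plenty for float32 convergence on well-scaled inputs; the point is that the trip count is known at trace time.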

Another point might be about reducing how many `cond` operations you use, to let accelerators work as effectively as possible. That is, sometimes it's better to use a `where` and compute both branches of a switch rather than use a `cond`. Knowing when to do that typically requires profiling, but there are some good rules of thumb.
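
A hedged sketch of the two styles (the function names are mine, not from JAX):

```python
import jax.numpy as jnp
from jax import lax

def f_cond(pred, x):
    # Executes only one branch, but the data-dependent choice can
    # serialize work on a SIMD accelerator.
    return lax.cond(pred, jnp.sin, jnp.cos, x)

def f_where(pred, x):
    # Computes both branches, then selects elementwise; often faster on
    # accelerators when both branches are cheap.
    return jnp.where(pred, jnp.sin(x), jnp.cos(x))
```

As noted above, whether the extra FLOPs of the `where` version pay off is workload-dependent, so profile before committing.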

It might also be interesting to discuss

  • static vs nonstatic attributes in pytrees, jitted functions, and custom derivatives (see the sketch after this list), and
  • imagining the propagation through a single code path of both primals and cotangents.
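
For the first bullet, a minimal sketch of static vs. nonstatic arguments under `jit` (the `moving_average` function is hypothetical):

```python
from functools import partial
import jax
import jax.numpy as jnp

# `window` is static: fixed at trace time, usable as a shape, and each new
# value triggers a recompilation. `x` is nonstatic: it is traced, and the
# compiled program is reused across different values.
@partial(jax.jit, static_argnums=1)
def moving_average(x, window):
    return jnp.convolve(x, jnp.ones(window) / window, mode='valid')
```

The same distinction shows up in pytrees: the auxiliary data returned by a pytree flattening rule is treated as static metadata, while the leaves are traced.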

A relevant user question: https://github.com/google/jax/issues/5280#issuecomment-752702613
