@sguada suggested we add a CIFAR10 example. Some docs for how to use lax.conv and lax.conv_general_dilated are also in order, for people who don't want to use stax.py. We might also want some additional convenience wrappers, since lax.conv_general_dilated is super general, and lax.conv has some unusual dimension orderings.
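For reference, here is a minimal sketch of the dimension-ordering point above. It assumes a CIFAR10-sized batch of dummy data; `conv2d_nhwc` is a hypothetical convenience wrapper invented for illustration, not an existing JAX function. It shows that lax.conv expects NCHW inputs and OIHW kernels, while lax.conv_general_dilated can be told to use NHWC/HWIO via `dimension_numbers`.

```python
import jax.numpy as jnp
from jax import lax

# lax.conv uses fixed layouts: inputs are NCHW, kernels are OIHW.
images_nchw = jnp.ones((8, 3, 32, 32))   # dummy batch: 8 images, 3 channels, 32x32
kernel_oihw = jnp.ones((16, 3, 3, 3))    # 16 output channels, 3 input channels, 3x3 window
out_nchw = lax.conv(images_nchw, kernel_oihw,
                    window_strides=(1, 1), padding="SAME")


def conv2d_nhwc(x, w, strides=(1, 1), padding="SAME"):
    """Hypothetical wrapper: 2D convolution with NHWC inputs and HWIO kernels."""
    return lax.conv_general_dilated(
        x, w,
        window_strides=strides,
        padding=padding,
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )


# Same data, rearranged into the channels-last layouts.
images_nhwc = jnp.transpose(images_nchw, (0, 2, 3, 1))
kernel_hwio = jnp.transpose(kernel_oihw, (2, 3, 1, 0))
out_nhwc = conv2d_nhwc(images_nhwc, kernel_hwio)

print(out_nchw.shape)  # (8, 16, 32, 32)
print(out_nhwc.shape)  # (8, 32, 32, 16)
```

Both calls compute the same convolution; only the layout of the inputs and outputs differs.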
Also, the Conv in stax should probably be called Conv2D.
What does lax stand for?
It stands for our general commitment to naming things with "ax" on the end, which is super funny.
I think the name arose as a transposition of the letters "x", "l", and "a", since it's mostly 1:1 with XLA HLO. It doesn't mean anything other than that, as far as I recall. @froystig might remember something though.
Hello @mattjj / @gnecula - I've recently started using JAX and would like to help add this. Is there a particular direction you'd like me to look at to get started?
If you think this is not useful, are there any other issues I can help out with?
P.S. - If successful this will be my first PR in JAX.
I'm not certain whether or not this predates the issue, but there is some pretty detailed discussion of lax.conv in the JAX docs here: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Convolutions
This is pretty detailed, actually. Thanks for sharing this, @jakevdp.
In that case this issue is redundant, unless you think we can enhance it further or create another version focused just on lax.conv (and related functions).
Other than that, do you recommend any other issues that a JAX noob like me could look into?
One good place to start is #70: it lists some of the numpy and scipy functionality that remains unimplemented (though the list is getting pretty small these days!)
This is amazing. Let me look into these in detail. Thanks!
I think this is adequately covered by the lax.conv_general_dilated reference documentation, the examples in the "Sharp bits" notebook, and the various neural network libraries built on top of JAX (e.g., Flax and Haiku).
Although if someone wanted to send a PR adding an example to the conv_general_dilated documentation, I'm sure that would be welcome!
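As a starting point for such an example, here is a rough sketch of what it might show: how strides, padding, and kernel dilation affect the output shape of lax.conv_general_dilated. The input sizes and variable names are assumptions chosen for illustration, not taken from the reference docs.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 32, 32, 3))   # NHWC: one dummy 32x32 image with 3 channels
w = jnp.ones((3, 3, 3, 8))     # HWIO: 3x3 window, 3 input channels, 8 output channels
dims = ("NHWC", "HWIO", "NHWC")

# Stride 2 with SAME padding: each spatial dim becomes ceil(32 / 2) = 16.
out_same = lax.conv_general_dilated(
    x, w, window_strides=(2, 2), padding="SAME", dimension_numbers=dims)

# Kernel (rhs) dilation 2 turns the 3x3 window into an effective 5x5 window.
# With VALID padding and stride 2: floor((32 - 5) / 2) + 1 = 14.
out_dilated = lax.conv_general_dilated(
    x, w, window_strides=(2, 2), padding="VALID",
    rhs_dilation=(2, 2), dimension_numbers=dims)

print(out_same.shape)     # (1, 16, 16, 8)
print(out_dilated.shape)  # (1, 14, 14, 8)
```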