Is there a recurrent layer in the pipeline for stax? I'm excited to try out JAX, but I mostly work with sequence data like amino acid sequences, where LSTMs are bread and butter. I couldn't find an API for that in the stax docs.
I don't know of anyone actively working on a recurrent layer for stax, though others have expressed interest too. Some people have written their own RNNs, though.
Stax is just an experiment, meant to show how easy it is to write layer libraries and to serve as inspiration. It's not meant to be exhaustive. The set of JAX examples (both using and not using stax) will grow in the future, but I encourage you not to feel limited by stax: if you can implement an LSTM in raw NumPy, you can use JAX!
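To make that concrete, here is a minimal sketch of a single LSTM step written directly in jax.numpy. The parameter layout and names (Wx, Wh, b) and the gate ordering are illustrative choices for this example, not part of stax or any JAX API:

```python
import jax.numpy as jnp

def sigmoid(z):
    return 1.0 / (1.0 + jnp.exp(-z))

def lstm_cell(params, carry, x):
    # One LSTM step written by hand in jax.numpy (illustrative, not a stax layer).
    Wx, Wh, b = params              # Wx: (in_dim, 4*hid), Wh: (hid, 4*hid), b: (4*hid,)
    h, c = carry                    # previous hidden and cell state, each (hid,)
    gates = jnp.dot(x, Wx) + jnp.dot(h, Wh) + b
    i, f, g, o = jnp.split(gates, 4, axis=-1)
    c = sigmoid(f) * c + sigmoid(i) * jnp.tanh(g)
    h = sigmoid(o) * jnp.tanh(c)
    return (h, c), h                # new carry, plus this step's output
```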
Let's leave this issue open so that we can check in some kind of RNN example in the future.
Ah, okay, I guess I had a different perspective on the use case for JAX. For research in a field like computational biology, surely I am not going to use my own LSTM implementation over something like PyTorch, which is heavily optimized. PyTorch also lets you build heavily customized architectures. So I am trying to understand JAX's place among all of the deep learning frameworks.
This feels to me like something that would belong in tensor2tensor; it looks like they already have an LSTM model, but they're currently missing a JAX version: https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/trax
One benefit I found with JAX is that JIT compilation can make your own custom code pretty competitive. I was messing around with my own LSTM implementation and found that JAX actually had the fastest forward pass on GPU after JIT compilation (comparing against the same low-level ops in PyTorch and TensorFlow). My implementation uses lax.scan, however, which was recently removed and hasn't been added back just yet, so I'm not sure how fast JAX computes gradients compared to the other frameworks.
The reference LSTM code can be found here, if it is any help.
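For reference, here is a sketch of what the forward pass could look like, reusing the illustrative lstm_cell above: lax.scan carries the hidden state across timesteps, and jax.jit compiles the whole loop with XLA. The shapes and initialization below are made up for the example:

```python
import jax
import jax.numpy as jnp
from jax import lax

def lstm_forward(params, xs):
    # Run the cell over a whole sequence. xs: (seq_len, in_dim); lax.scan
    # loops over the leading (time) axis and carries (h, c) between steps.
    _, Wh, _ = params
    hid = Wh.shape[0]
    init_carry = (jnp.zeros(hid), jnp.zeros(hid))
    step = lambda carry, x: lstm_cell(params, carry, x)
    (h, c), outputs = lax.scan(step, init_carry, xs)
    return outputs                  # (seq_len, hid)

lstm_forward_jit = jax.jit(lstm_forward)

# Toy usage with made-up shapes:
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
in_dim, hid = 8, 16
params = (0.1 * jax.random.normal(k1, (in_dim, 4 * hid)),
          0.1 * jax.random.normal(k2, (hid, 4 * hid)),
          jnp.zeros(4 * hid))
xs = jax.random.normal(k3, (20, in_dim))     # a length-20 sequence
outputs = lstm_forward_jit(params, xs)       # (20, hid)
```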
surely I am not going to use my own LSTM implementation over something like pytorch which is heavily optimized.
That's a great point to bring up!
JAX gets its performance from XLA. By relying on an end-to-end optimizing compiler like XLA, the idea is that you don't need to depend on specialized implementations for performance; you can write your own code from scratch and XLA will heavily optimize it for you. That's the model at least, and we've been pretty impressed with XLA performance so far.
That said, XLA is still developing rapidly. If you think it's not giving you great performance on some application, open a bug with example code and we and/or the XLA performance team will be glad to dig in.
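As a toy illustration of that model, the snippet below jit-compiles a hand-written tanh-approximation GELU built only from primitive ops; the function name is made up, and the point is just that XLA fuses the elementwise ops rather than needing a specialized kernel:

```python
import jax
import jax.numpy as jnp

def my_gelu(x):
    # Several elementwise ops written by hand; under jit, XLA fuses them
    # instead of dispatching one kernel per operation.
    return 0.5 * x * (1.0 + jnp.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))

my_gelu_jit = jax.jit(my_gelu)

x = jnp.linspace(-3.0, 3.0, 1024)
y = my_gelu_jit(x).block_until_ready()   # block_until_ready() for honest timing
```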
Now that there's a differentiable scan, I finally have a working toy LSTM example!
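For anyone following along, here is a rough sketch of what a jit-compiled gradient update through lax.scan could look like, assuming the illustrative lstm_forward and parameter tuple from the sketches above; the loss function and learning rate are placeholders:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, xs, targets):
    preds = lstm_forward(params, xs)           # scan-based forward pass from above
    return jnp.mean((preds - targets) ** 2)    # toy squared-error loss

@jax.jit
def sgd_update(params, xs, targets, lr=0.01):
    # jax.grad differentiates straight through lax.scan; the whole update
    # is then compiled as a single XLA computation.
    grads = jax.grad(loss_fn)(params, xs, targets)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```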
Woohoo, awesome @sharadmv ! Thanks for sharing that.
Any initial findings about scan? Thoughts on ergonomics, performance improvements to make, etc?
These swapaxes calls are interesting to me, since I'm currently working on vmap of scan. Ergonomics-wise, how valuable would it be to scan over different axes? (Not sure if we'd add that because I'm not sure about the power-to-weight, but I wanted to solicit early feedback!)
I actually think scan is very clean right now, and the docstring explains it very well! As far as performance goes, I'm finding that JIT-compiling gradient updates results in a 6-7x speed improvement! I haven't explored different sequence lengths just yet, but definitely something I'll look into soon.
Most scan implementations I've worked with (Theano, TensorFlow) only loop over the first axis (you can see how I handle it here). I don't think it's necessary to have the loop axis as an argument, but it would be a minor QOL improvement.
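For concreteness, a sketch of that swapaxes handling for batch-major inputs, assuming a step function that already handles a batch dimension (e.g. a cell wrapped in vmap); the names here are illustrative:

```python
import jax.numpy as jnp
from jax import lax

def scan_batch_major(step, init_carry, xs):
    # xs: (batch, seq_len, features). lax.scan only loops over the leading
    # axis, so swap time to the front, scan, then swap the outputs back.
    xs_time_major = jnp.swapaxes(xs, 0, 1)                 # (seq_len, batch, features)
    carry, ys = lax.scan(step, init_carry, xs_time_major)  # ys: (seq_len, batch, ...)
    return carry, jnp.swapaxes(ys, 0, 1)                   # (batch, seq_len, ...)
```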