Awesome project! I really like how you have been able to make all the state explicit and pass it in to the models, optimizers, and random number generators. The performance gains I've seen from just adding a simple jit annotation to my functions are phenomenal. However, I'm really curious whether you have considered making a similar library in Haskell. It seems like with each new neural network library we are trying to bring functional programming concepts into the mainstream. From the TensorFlow data API to now JAX, we are getting closer and closer to a functional language built on top of Python. Some things, like control flow, still don't seem quite there, which is why we need cond and while, which are again Lisp/Haskell concepts. Couldn't we make a better API and experience if, rather than trying to shoehorn a functional paradigm into imperative languages, we actually went to where the functional style is the natural one? We could model stax.serial as a monad and stax.parallel as an applicative. Facebook's Haxl library could help make parallelization natural and mostly automated, too.
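For readers less familiar with JAX, here is a minimal sketch of the structured control flow mentioned above. Under jax.jit, a Python if or while cannot branch on traced values, so the branch and the loop are written with lax.cond and lax.while_loop, and all loop state is threaded through explicitly as a tuple. The function collatz_steps is a hypothetical example, not part of JAX.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def collatz_steps(n):
    """Count Collatz steps from n down to 1 (hypothetical example)."""

    # The loop carry is an explicit (value, step count) tuple.
    def cond_fun(state):
        value, _ = state
        return value != 1

    def body_fun(state):
        value, steps = state
        # lax.cond picks a branch based on a traced predicate;
        # both branches must return values of the same shape and dtype.
        value = lax.cond(value % 2 == 0,
                         lambda v: v // 2,
                         lambda v: 3 * v + 1,
                         value)
        return value, steps + 1

    _, steps = lax.while_loop(cond_fun, body_fun, (n, jnp.int32(0)))
    return steps


print(collatz_steps(6))
```

The whole loop is staged out and compiled once, instead of being unrolled in Python, which is part of why these combinators look more like fold/iterate in a functional language than like imperative loops.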
Thanks for the kind words and encouragement!
We're big fans of functional programming and Haskell. In fact, when @dougalm redesigned and rewrote JAX's core this past year, he first built a toy model in Haskell, which did wonders to clarify our thinking. (He had so much fun with it that for a while it was hard to get him to write Python again!)
JAX also draws on a lot of others' work in the functional programming community, like Oleg Kiselyov's Typed Tagless Final Interpreters (JAX interpreters are essentially untyped tagful final interpreters that leverage partial evaluation) as well as Conal Elliott's Compiling to Categories and The simple essence of automatic differentiation. (Someday we'll write a paper that explains the formal view of JAX and its connection to those ideas, but paper writing keeps getting bumped in favor of developing the software...)
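To make the "tagful" part of that description concrete, here is a deliberately tiny Python sketch that assumes nothing about JAX's real internals. The names JVPTracer, add, mul, and jvp are hypothetical. The user function is written once against primitive operations, and the runtime tag of each value (its Python class) decides whether a primitive is simply evaluated or interpreted with forward-mode differentiation rules. It leaves out the partial evaluation and the nesting of interpreters that JAX actually supports.

```python
class JVPTracer:
    """Hypothetical tracer: carries a primal value and its tangent."""

    def __init__(self, primal, tangent):
        self.primal = primal
        self.tangent = tangent


def lift(x):
    # Untagged numbers are treated as constants with zero tangent.
    return x if isinstance(x, JVPTracer) else JVPTracer(x, 0.0)


def add(x, y):
    # Dispatch on the tag: a tracer selects the forward-mode rule.
    if isinstance(x, JVPTracer) or isinstance(y, JVPTracer):
        x, y = lift(x), lift(y)
        return JVPTracer(x.primal + y.primal, x.tangent + y.tangent)
    return x + y  # plain evaluation


def mul(x, y):
    if isinstance(x, JVPTracer) or isinstance(y, JVPTracer):
        x, y = lift(x), lift(y)
        # Product rule: d(xy) = dx * y + x * dy
        return JVPTracer(x.primal * y.primal,
                         x.tangent * y.primal + x.primal * y.tangent)
    return x * y  # plain evaluation


def f(x):
    # User code written once against the primitives: x**2 + 3x
    return add(mul(x, x), mul(3.0, x))


def jvp(fun, primal, tangent):
    # Run the same function under the differentiating interpretation.
    out = fun(JVPTracer(primal, tangent))
    return out.primal, out.tangent
```

Here f(2.0) evaluates to 10.0, while jvp(f, 2.0, 1.0) returns (10.0, 7.0): the value of x**2 + 3x at 2 together with its derivative 2x + 3.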
While @dougalm is the best person to answer about Haskell specifics, I can think of a few high-level reasons that JAX makes sense in Python rather than Haskell for now. The first is mostly practical: Python has the users and related software stacks that we're interested in supporting and working with. The second is that we're exploring a research hypothesis around how much we can do in a dynamic language like Python at the user library level. The third, more technical reason might be the most compelling to you: we leverage Python's dynamic typing flexibility in quite a few places, and getting some parts of JAX to work with Haskell's type system might be pretty hard or impossible, even with all the language extensions Dougal was using. Dougal could comment much more coherently, though!
That said, I'd love to see a real Haskell package that draws from and builds on the ideas in JAX. If there's ever a place for functional programming to shine, numerical computing of the sort found in modern ML seems to be it.
This is an interesting topic for discussion, but for issue-tracking purposes, I'm going to close this issue. Please reopen if needed!
Is there any sort of ETA on that "formal view of JAX" paper?
Thanks for the interest!
The most recent ETA I can report is November 2019. It doesn't look like we'll make it.
We have made some progress, but two barriers are (1) we're still figuring out what JAX really is and (2) we're having too much fun building the software artifact.
What kind of paper would be most interesting to you? A PL-ish paper? Systems-y paper? A machine learning tools kind of paper?
Definitely something on the PL side of things!
For reference, I'm looking at adding automatic differentiation to Futhark and want to understand JAX better so that I can effectively compare the two and inform design decisions in Futhark.
Oh awesome! We are fans of Futhark, @dougalm has told me many awesome things about it.
JAX is an evolution and improvement of our AD ideas in Autograd, but I think Dex is a further evolution still. That makes sense, because @dougalm first designed Autograd, then designed JAX's core and AD, and now is designing Dex! So it might be that any ideas you could glean from JAX you can also glean from Dex, perhaps even cleaned up and more concisely presented. (There are some ideas in JAX related to embedding in Python that aren't present in Dex, though I'm not sure if they'd be interesting for Futhark.)
We'd still like to write down the AD ideas in JAX, but no ETA on that. Having +1s like this helps motivate us :)
Thanks for the Dex tip; I was unaware of it until now. I read the workshop paper (cool!) and took a gander at the source. It's nice to have some Haskell to reference (Futhark is also written in Haskell).
Anyway, I'll be sure to gobble up any paper that's written on JAX, so you have my +1 for sure.
I think a write-up would be of interest to the Hasktorch people as well. :)