We're beginning to acquire some questions, some of which have been asked frequently. We're collecting them in this issue for eventual inclusion in a more polished markdown file.
- grad something non-differentiable?
- "XYZ not implemented for ABC" error (grad, vmap, jit)

I'd like to add a question, @alexbw!

- Does JIT-compilation compile the for-loops in my code, or does it only compile the array computations to GPU/TPU?

(Forgive me if I have any misconceptions here about what JIT-compilation is all about... if my question "isn't even wrong", please let me know!)
That is a good question! Right now you have to use currently undocumented lax.cond and lax.while constructs that look and act like TF's cond and while_loop. We are thinking about how we can remove this constraint and automatically convert ifs and fors for you. Would that be useful for you?
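Concretely, here's roughly what those structured control-flow primitives look like in use. This is a minimal sketch against the current jax.lax API, whose names and signatures have evolved since this exchange: the while construct is lax.while_loop, and lax.cond takes branch functions directly.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def relu_or_square(x):
    # Branch on a traced value: lax.cond selects which function to apply.
    return lax.cond(x > 0, lambda v: v, lambda v: v * v, x)

@jax.jit
def countdown(n):
    # A loop whose exit condition depends on the traced value of n.
    return lax.while_loop(lambda v: v > 0, lambda v: v - 1, n)

print(relu_or_square(jnp.float32(-3.0)))  # 9.0
print(countdown(10))                      # 0
```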
@alexbw yes, it would! For most Python programmers, it is much more natural to write for-loops. Good to know that the cond and while constructs exist, but yes, for and if would play nicely with Pythonic conventions!
There's something to tease apart here: you can put for-loops in @jit functions and they'll compile just fine if the loop bounds don't depend on the values of the arguments to the function (but instead only depend on their shapes, or some other fixed values). The XLA code that gets compiled will have those loops unrolled.
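For instance, here's a minimal sketch of the unrolling case. The function is just illustrative; static_argnums marks the loop bound as a static value rather than a traced argument.

```python
from functools import partial
import jax
import jax.numpy as jnp

# n is marked static, so inside the traced function it is an ordinary Python
# int, and the for-loop below gets unrolled into the compiled XLA program.
@partial(jax.jit, static_argnums=1)
def matpow(x, n):
    acc = jnp.eye(x.shape[0], dtype=x.dtype)
    for _ in range(n):
        acc = acc @ x
    return acc

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
print(matpow(x, 5))  # five matmuls, unrolled at trace time
```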
This lax.while business is about compiling XLA code that itself has loop constructs in it, rather than unrolled loops. You need that if, for example, the loop exit condition depends on the value of an input argument to the @jit function. Even if the loop bounds were static, it can be preferable to generate compiled loop constructs because they might reduce compile times.
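For the static-bounds case where you'd rather keep a rolled loop in the compiled program (e.g. to cut compile times), lax.fori_loop is the relevant construct. A minimal sketch, assuming the current lax.fori_loop signature:

```python
import jax
from jax import lax

@jax.jit
def iterate_logistic(x):
    # 1000 iterations of the logistic map, compiled as a single rolled loop
    # in XLA rather than 1000 unrolled multiply-adds.
    return lax.fori_loop(0, 1000, lambda i, v: 3.9 * v * (1.0 - v), x)

print(iterate_logistic(0.5))
```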
I think @alexbw's answer was about the latter case (generating XLA code with loop constructs in it), but I want to underscore that the former case (unrolling Python loops into XLA code) works without using any special constructs.
It's probably also useful to underscore that:

- If you use @jit on a function with Python control flow that depends on the values of the @jit function arguments, you'll get a loud error. No silent failures!
- These tracing constraints only apply to jit and vmap. Automatic differentiation using functions like grad doesn't have any of these constraints, and so it works just like in Autograd with no need for special control flow constructs.
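A small sketch of that last point, with an illustrative function: grad traces through ordinary value-dependent Python control flow with no lax constructs needed, while jit on the same function would raise the loud error mentioned above.

```python
import jax

def f(x):
    # Value-dependent Python control flow: fine under grad, just like Autograd.
    if x > 0:
        return x ** 2
    return -x

print(jax.grad(f)(3.0))   # 6.0
print(jax.grad(f)(-3.0))  # -1.0

# jax.jit(f)(3.0) would raise a loud tracing error here, because under jit
# the value of x is abstract and `x > 0` can't be turned into a Python bool.
```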
@mattjj thanks for the helpful response, and hello from Cambridge! :smile:

Yes, @mattjj, thank you for the clarification; my initial response was inaccurate because it was incomplete.
Hi, I would also like to add a question, which is probably related to the first two:
What is stax? How complete is it supposed to be? How does it compare to tensorflow/pytorch/other commonly used DL frameworks?
Here's a relevant paragraph from the README:
JAX provides some small, experimental libraries for machine learning. These libraries are in part about providing tools and in part about serving as examples for how to build such libraries using JAX. Each one is only a few hundred lines of code, so take a look inside and adapt them as you need!
To that end, stax is a minimalistic library for building neural networks. It's not meant to be complete, and that's the main way it compares to other libraries: it's pretty limited. But its power-to-weight ratio is pretty high: it's only a couple hundred lines of code!
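For a flavor of what that looks like, here's a minimal sketch of building and initializing a small network with stax. Note the import path has moved across JAX versions (jax.experimental.stax in older releases, jax.example_libraries.stax more recently), so adjust accordingly.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax  # older JAX versions: jax.experimental.stax

# A small MLP: each stax layer is an (init_fun, apply_fun) pair, and
# stax.serial composes a sequence of layers into one such pair.
init_fun, apply_fun = stax.serial(
    stax.Dense(128), stax.Relu,
    stax.Dense(10), stax.LogSoftmax,
)

rng = jax.random.PRNGKey(0)
out_shape, params = init_fun(rng, (-1, 784))      # initialize parameters
logits = apply_fun(params, jnp.ones((32, 784)))   # forward pass
print(out_shape, logits.shape)                    # (-1, 10) (32, 10)
```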
Many users have found it to be a useful starting point for writing their own libraries. Some are already open-sourced, like trax, which has a lot more capabilities, models, and features, and I can tell you there are several more being developed by users inside Google that I suspect will be open-sourced soon.
The JAX core team is planning to make a more complete yet still minimalistic library over the next several months. It may never include all the bells and whistles that other libraries do, but we think we can do more than just stax and have it be both a useful tool and a jumping-off point for how others write their own libraries.
WDYT?
Huuuum, I think I got it, though it is an uncommon goal (compared to most other dl libraries). So, are you still accepting contributions to stax? Or do you see it as something that only the core team should/will touch?
We're definitely accepting contributions and bugfixes, but we're also fine with people who rely on the current version of stax making their own copies of stax.py, especially if you want to add many layers or features.
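If you do copy or extend it, a new layer is just another (init_fun, apply_fun) pair that can be dropped into stax.serial alongside the built-in ones. A hedged sketch with an illustrative layer name (not part of stax itself):

```python
# A hypothetical custom layer named Scale (illustrative, not part of stax):
# like every stax layer, it is just an (init_fun, apply_fun) pair.
def Scale(factor):
    def init_fun(rng, input_shape):
        return input_shape, ()              # no trainable parameters
    def apply_fun(params, inputs, **kwargs):
        return factor * inputs
    return init_fun, apply_fun
```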
We have an FAQ now! It's not exactly what was outlined here, but, well, we can grow it as we find new frequently asked questions.