I am interested in training recurrent networks whose transition dynamics have some sort of time-dependence. For example, the network might evolve linearly from time t1=0 to time t2 and then stay clamped at some constant parameter array u. In normal Python code I might write something like this:
```python
for step in range(n_steps):
    x = a.dot(x) if step < t2 else u
```
I would like to differentiate through these dynamics using reverse-mode autodiff, so I've been trying to use lax.scan.
However, I'm not sure how to introduce time-dependence into the scanning function f. Right now, I've defined two transition functions, f1 and f2, one for each of the two cases:
```python
carry, _ = lax.scan(f1, x0, length=t2)
carry, _ = lax.scan(f2, carry, length=n_steps - t2)
```
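For concreteness, a minimal runnable sketch of this two-scan approach, assuming made-up values for a, x0, u, n_steps and t2 (none of these come from the original post); f1 and f2 here just implement the linear and clamped phases of the example loop:

```python
import jax.numpy as jnp
from jax import lax

# Made-up problem data, for illustration only.
a = 0.9 * jnp.eye(3)      # linear transition matrix
x0 = jnp.ones(3)          # initial state
u = 2.0 * jnp.ones(3)     # clamp value
n_steps, t2 = 10, 4

def f1(carry, _):
    # Linear phase (steps 0 .. t2-1): x <- a @ x
    return a.dot(carry), None

def f2(carry, _):
    # Clamped phase (steps t2 .. n_steps-1): x <- u
    return u, None

def run(x0):
    carry, _ = lax.scan(f1, x0, length=t2)
    carry, _ = lax.scan(f2, carry, length=n_steps - t2)
    return carry
```

Each extra phase in the dynamics needs another body function and another scan call, which is the bookkeeping the question is trying to avoid.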
This gets quite annoying when the transition dynamics are more complicated.
Therefore, I was wondering whether it would be possible to have a function lax.scani that takes a scanning function f with type signature f : int -> c -> a -> (c, b), where the first argument of f is the index of the element being scanned and, importantly, this integer index can be used for control flow. In the example above, we would have:
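```python
def f(t, carry, x):
    new_carry = a.dot(carry) if t < t2 else u
    return new_carry, None  # (c, b), matching the proposed signature

carry, _ = lax.scani(f, x0, length=n_steps)  # proposed function, not an existing JAX API
```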
Thanks for the question!
One way to write it is like this:
```python
def f(carry, i_x):
    i, x = i_x
    ...

carry, ys = lax.scan(f, init_carry, (np.arange(n_steps), xs))
```
but then i is a traced value inside the body of f, so you couldn't use Python control flow on it; you'd need to use lax.cond instead.
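A minimal sketch of this pattern applied to the linear-then-clamped example, with the data a, u, n_steps and t2 as made-up assumptions. Since there are no per-step inputs here, only the index array is scanned over, and the Python `if step < t2` becomes a lax.cond on the traced index:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Made-up problem data, for illustration only.
a = 0.9 * jnp.eye(3)
u = 2.0 * jnp.ones(3)
n_steps, t2 = 10, 4

def f(carry, i):
    # i is a traced integer, so branch with lax.cond rather than Python `if`.
    new_carry = lax.cond(i < t2,
                         lambda c: a.dot(c),  # linear phase
                         lambda c: u,         # clamped phase
                         carry)
    return new_carry, None

def run(x0):
    carry, _ = lax.scan(f, x0, jnp.arange(n_steps))
    return carry

x0 = jnp.ones(3)
final = run(x0)
# Reverse-mode gradient through scan + cond; the final state is clamped
# to u, so in this particular setup the gradient is all zeros.
g = jax.grad(lambda x: run(x).sum())(x0)
```

Because the branch is a lax.cond on a traced value, the same body also works under jit.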
Would the dependence on i be arbitrary, or is there some regularity to it?
Thanks for the fast response. I've considered doing what you suggested, but the inability to do control flow on i was the main reason that I didn't.
I wasn't aware of the function lax.cond. Would I be able to do control flow on i using lax.cond then? A use case I have in mind is:
```python
x = a.dot(x) if i > 0 else x
```
I'm not sure if this is considered arbitrary.
Thanks again for your help!
Can't you put the time into your carry, and increment it in f?
Hi Neil, thanks for the suggestion; I certainly can. I guess my remaining problem is figuring out how to use lax.cond to do control flow on the time index i in a way that is differentiable, as @mattjj suggested above, which I haven't really explored yet.
@tachukao yes, when you use lax.cond, the control flow you write can always be staged out (e.g. by jit, or by using it in a scan body) and also differentiated. It's awkward, but it's the only robust way we've found to embed structured control flow in Python.
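A sketch combining both suggestions from this thread: the time index is kept in the carry and incremented each step, and `x = a.dot(x) if i > 0 else x` is written with lax.cond so the whole loop can be jitted and differentiated in reverse mode. The matrix a, x0 and n_steps are made-up assumptions:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Made-up problem data, for illustration only.
a = jnp.array([[0.5, 0.1],
               [0.0, 0.8]])
n_steps = 5

def step(carry, _):
    i, x = carry  # the time index lives in the carry
    # Traceable equivalent of `x = a.dot(x) if i > 0 else x`.
    x = lax.cond(i > 0, lambda v: a.dot(v), lambda v: v, x)
    return (i + 1, x), None  # advance the time index

@jax.jit
def run(x0):
    # jnp.array(0) gives an int32 index whose dtype stays fixed across steps.
    (_, x), _ = lax.scan(step, (jnp.array(0), x0), length=n_steps)
    return x

x0 = jnp.array([1.0, 2.0])
x_final = run(x0)
J = jax.jacrev(run)(x0)  # reverse-mode Jacobian through scan + cond
```

With n_steps = 5, step 0 is the identity and steps 1 to 4 apply a, so the result is a^4 @ x0 and the Jacobian is a^4.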
You can always avoid all this structured control flow stuff (lax.scan, lax.cond, etc.) and write things with regular Python for-loops and ifs. JAX can differentiate native Python! But if you use jit on a Python loop, compile times may get long, because the loop is essentially unrolled into the XLA computation. (The purpose of lax.scan is to stage out a loop construct to XLA without unrolling, and thus give good compile times.)
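As a small illustration of differentiating native Python, here is the loop from the original question written as a plain for-loop with an ordinary `if`, then differentiated with jax.grad. The data and the sum-of-squares loss over the trajectory are made-up assumptions, chosen so the gradient is not trivially zero:

```python
import jax
import jax.numpy as jnp

# Made-up problem data, for illustration only.
a = 0.9 * jnp.eye(3)
x0 = jnp.ones(3)
u = 2.0 * jnp.ones(3)
n_steps, t2 = 10, 4

def loss(a, x0):
    x, total = x0, 0.0
    for step in range(n_steps):
        # `step` is an ordinary Python int, so a plain `if` is fine here.
        x = a.dot(x) if step < t2 else u
        total = total + jnp.sum(x ** 2)
    return total

grad_a = jax.grad(loss)(a, x0)                # differentiates the Python loop directly
grad_a_jit = jax.jit(jax.grad(loss))(a, x0)   # also works; the loop is unrolled while tracing
```

Under jit, every iteration becomes separate operations in the compiled computation, so compile time grows with n_steps.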
Here's sketch code for how you might write it so that the loop and other control flow stay in Python, but you can still use jit on some parts:
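```python
from functools import partial
from jax import jit

@jit
def f(params, hidden, x):
    ...  # one step of the dynamics; returns (new_hidden, y)

@jit
def g(params, hidden, x):
    ...  # another step of the dynamics; returns (new_hidden, y)

...  # further step functions as needed

def rnn(params, hidden, inputs):
    outputs = []
    for i, x in enumerate(inputs):
        # i is a plain Python int here, so ordinary if/elif works.
        if i % 10 == 0:
            hidden, y = f(params, hidden, x)
        elif i % 10 == 1:
            hidden, y = g(params, hidden, x)
        else:
            ...  # remaining cases
        outputs.append(y)
    return hidden, outputs
```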
You only need to write things in terms of lax.scan/lax.cond if you need more performance because you want to jit the whole rnn function.
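A sketch of what jitting the whole rnn function could look like: the Python loop becomes a lax.scan over (index, input) pairs, and the if/elif chain on i % 10 becomes lax.switch, JAX's n-way counterpart to lax.cond (swapped in here because there are more than two branches). The step functions f, g and h are hypothetical stand-ins for the elided ones, with h covering the remaining cases; lax.switch requires every branch to return outputs with the same shapes and dtypes:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical step functions; each returns (new_hidden, y) with matching types.
def f(params, hidden, x):
    hidden = jnp.tanh(params @ hidden + x)
    return hidden, hidden.sum()

def g(params, hidden, x):
    hidden = 0.5 * hidden + x
    return hidden, hidden.sum()

def h(params, hidden, x):  # stand-in for the remaining cases
    return hidden, hidden.sum()

@jax.jit
def rnn(params, hidden, inputs):
    def body(hidden, i_x):
        i, x = i_x
        # i % 10 == 0 -> f, i % 10 == 1 -> g, anything else -> h
        branch = jnp.minimum(i % 10, 2)
        return lax.switch(branch, [f, g, h], params, hidden, x)
    idx = jnp.arange(inputs.shape[0])
    return lax.scan(body, hidden, (idx, inputs))  # (final_hidden, stacked ys)

# Made-up data for a quick check.
params = 0.1 * jnp.ones((2, 2))
h0 = jnp.zeros(2)
inputs = jnp.linspace(0.0, 1.0, 24).reshape(12, 2)
final_hidden, ys = rnn(params, h0, inputs)
```

The whole recurrence now compiles to a single XLA loop, and jax.grad can be taken through it with respect to params.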
If we introduced a lax.scani kind of function, it'd just be a wrapper around lax.scan and lax.cond, but our policy is to avoid wrappers unless they're very commonly needed.
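To illustrate how thin such a wrapper would be, here is a sketch of a hypothetical scan_with_index helper (not a JAX API) that passes the step index to the body as a traced integer. It does not make a Python `if` on the index work; branching inside the body still has to go through lax.cond:

```python
import jax
import jax.numpy as jnp
from jax import lax

def scan_with_index(f, init, xs=None, length=None):
    """Hypothetical helper, not part of JAX: like lax.scan, but the body is
    called as f(i, carry, x), with i the step index as a traced int32."""
    if length is None:
        length = jax.tree_util.tree_leaves(xs)[0].shape[0]

    def body(carry, i_x):
        i, x = i_x
        return f(i, carry, x)

    # A None xs is an empty pytree, so x is simply None inside the body.
    return lax.scan(body, init, (jnp.arange(length), xs))

# Usage on the linear-then-clamped example (made-up data):
a = 0.9 * jnp.eye(3)
u = 2.0 * jnp.ones(3)
n_steps, t2 = 10, 4

def f(t, carry, x):
    # t is traced, so the branch still uses lax.cond.
    new_carry = lax.cond(t < t2, lambda c: a.dot(c), lambda c: u, carry)
    return new_carry, t  # also emit the index, just to show it is available

carry, ts = scan_with_index(f, jnp.ones(3), length=n_steps)
```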
I think we covered the original question, so I'm going to close this issue (otherwise we'll drown in issues!), but please open a new one if you have new questions!
Thanks @mattjj! That makes a lot of sense 👍