Hi there. I have read through the tutorial part of the JAX docs and am still confused about how reverse AD is implemented in JAX. I have the following specific questions.
1) As mentioned in the defining-new-primitives documentation, one has to add transpose rules to a primitive so that it can be differentiated in reverse mode. But I cannot find transpose rules for most primitives in the JAX codebase; for example, most linalg ops like svd have no transpose rules defined as far as I can see, though grad(svd) works perfectly.
2) The function defining a transpose rule seems to be much the same thing as a vjp function (cotangent of y to cotangent of x); is this understanding right? If so, what is gained by defining transpose rules instead of directly defining vjp rules for each primitive?
3) What happens when defining a custom jvp? How can reverse AD work automatically with only jvp rules? To be more specific, how should I understand the following code?
@custom_jvp
def f(x):
    return 0.

@f.defjvp
def f_jvp(primals, tangents):
    return 0., tangents[0]**3.

g = grad(f)
g(2.)
# NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'pow' not implemented
# reverse AD fails
@custom_jvp
def f1(x):
    return 0.

@f1.defjvp
def f1_jvp(primals, tangents):
    return 0., tangents[0]

g1 = grad(f1)
g1(2.)
# 1, reverse AD works
def f2(x):
    return x**3.

g2 = grad(f2)
g2(2.)
# 12
If there is no transpose rule for pow, how does reverse AD work on f2?
Pointers to the relevant parts of the documentation/notes/source code are more than welcome.
Ok, after some exploration, I have some rough ideas and summarize them here.
Transpose rules apply to jvp functions: they are designed for primitives that take tangents as some of their inputs, and they specify how, in the reverse AD setting, an output cotangent propagates back to the cotangents of the inputs. In this way, only very few primitives need transpose rules, since the tangent part is always linear in the jvp setting. In the SVD case, this means it is enough to define transpose rules only for matmul (though I still cannot find them). If this is the case, the gain is obvious compared to implementing vjp rules for all primitives.
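As a small sketch of what "transposing a linear map" means here, JAX exposes jax.linear_transpose, which is essentially what the transpose rules implement primitive by primitive. This toy example (my own, not from the thread) transposes a matrix-vector product:

```python
import jax
import jax.numpy as jnp

A = jnp.array([[1., 2.], [3., 4.]])

def matvec(x):
    # linear in x, so it can be transposed
    return A @ x

# the transpose of the linear map x -> A @ x is y -> A.T @ y;
# the second argument only provides the input shape/dtype
matvec_t = jax.linear_transpose(matvec, jnp.zeros(2))
(ct,) = matvec_t(jnp.array([1., 0.]))
print(ct)  # [1. 2.], i.e. A.T @ [1., 0.]
```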
As for the adding-a-new-primitive doc, the transpose rule definition is unnecessary if, in the jvp rule, we implement the tangent part with * and + instead of the new primitive itself.
The above partially answers 1 and 2.
As for 3: for a plain pow function like f2, the tangent part of its jvp rule is a plain * rather than another pow, so it is ready to be transposed and supports reverse AD. This is in contrast to the tangents[0]**3 case in f, where pow is applied directly to the tangent.
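This point can be checked directly: jax.linearize extracts the tangent-linear map of f2 at a point, and because that map is built from mul (with the primal-only factor 3*x**2 baked in), jax.linear_transpose can turn it into a VJP without ever needing a transpose rule for pow:

```python
import jax

def f2(x):
    return x ** 3.

# linearize returns the primal output and the tangent-linear map at x=2.0
y, f_jvp = jax.linearize(f2, 2.0)
# the tangent map is t -> 12.0 * t: a mul, with no pow in sight
print(f_jvp(1.0))  # 12.0

# since the tangent map is linear, JAX can transpose it into a VJP
f_vjp = jax.linear_transpose(f_jvp, 1.0)
print(f_vjp(1.0))  # (12.0,)
```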
from jax import custom_vjp, custom_jvp, jacfwd, jacrev

# customize jvp
@custom_jvp
def f2(x):
    return x

@f2.defjvp
def f2_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    x = f2(x)
    return x, x_dot

g2 = jacrev(f2)
g2(2.)
# works

g2 = jacfwd(f2)
g2(2.)
# works

# customize vjp
@custom_vjp
def f1(x):
    return x

def f1_fwd(x):
    x = f1(x)
    return x, x

def f1_bwd(res, cotangents):
    dy = cotangents
    return (dy,)

f1.defvjp(f1_fwd, f1_bwd)

g = jacrev(f1)
g(2.)
# works

g = jacfwd(f1)
g(2.)
# fails with NotImplementedError: Batching rule for 'custom_lin' not implemented
Some quick comments on your questions:
1) The informal rule is: you define JVP rules for non-linear primitives and transpose rules for linear primitives.
In general, any expression that appears in a tangent calculation and is not a function of primal values alone must be linear, and to support reverse-mode AD it must have a transposition rule with respect to the tangent argument. This is what allows us to build reverse-mode AD from forward-mode JVP rules. An interesting example to consider is something like div, which is linear in its first argument but nonlinear in its second.
2) We use this design because it is our belief that writing a JVP rule and (occasionally) a transpose rule is simpler than writing both a JVP rule and a VJP rule. Only a small number of primitives need transpose rules, and in general JVP rules seem simpler to write than VJP rules ("dual number arithmetic").
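To illustrate this workflow (a sketch, not from the thread): for a nonlinear function with a custom JVP whose tangent expression is linear in the tangent, reverse mode comes for free, because only the mul in the tangent expression needs transposing:

```python
import jax

@jax.custom_jvp
def cube(x):
    return x ** 3.

@cube.defjvp
def cube_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # the tangent output 3*x**2 * t is linear in t; the x**2 part
    # depends only on primals, so no transpose rule for pow is needed
    return x ** 3., 3. * x ** 2. * t

print(jax.grad(cube)(2.0))            # 12.0, via transposing the tangent-linear mul
print(jax.jvp(cube, (2.0,), (1.0,)))  # (8.0, 12.0)
```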
3) Have you also seen:
https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
Does that answer your questions?
@hawkinsp, thanks, your answer helps a lot, and it confirms my observation in reply 1 here.
I have two remaining questions:
1) Regarding 4 in my reply 2: can you confirm whether the following statement is correct:
jax can automatically handle reverse AD if jvp is customized, but jax cannot handle forward AD automatically for functions with vjp customized
2) In linalg operations, the tangent part often enters via matmul, but I cannot find transpose rules for matmul by a quick search. Can you point out how matmul transpose rules are handled in JAX?
jax can automatically handle reverse AD if jvp is customized, but jax cannot handle forward AD automatically for functions with vjp customized
Correct, but reverse AD is free when defined via custom_jvp only if the function is non-linear, so that it doesn't appear in any JVP rules. Otherwise (for linear functions) you need a transpose rule.
- In linalg operations, the tangent part often enters via matmul, but I cannot find transpose rules for matmul by a quick search. Can you point out how matmul transpose rules are handled in JAX?
JAX's matmul is defined in terms of the lax.dot_general primitive, which has transpose rules defined on these lines:
https://github.com/google/jax/blob/d55ea510e29540f6e3ba79141bf5f6cceefc20d4/jax/lax/lax.py#L2537-L2558
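As a concrete check (my own example, not from the thread), the effect of that transpose rule is visible when differentiating through a matmul: for a loss summing C = A @ B, the cotangent of A is dC @ B.T, with dC all ones here:

```python
import jax
import jax.numpy as jnp

A = jnp.ones((2, 3))
B = jnp.ones((3, 4))

def loss(A):
    # jnp.matmul / @ lowers to the lax.dot_general primitive
    return jnp.sum(A @ B)

# reverse AD through the matmul relies on dot_general's transpose rule:
# the cotangent of A is dC @ B.T = ones((2, 4)) @ B.T
dA = jax.grad(loss)(A)
print(dA)  # every entry is 4.0
```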
@shoyer , thanks a lot! Your reply helps and solves my questions.