First of all, thank you for supporting complex numbers! Other similar tools have spent years on this problem and still don't have it working. I was wondering if there's a good way to solve the problem below.
Background: The _exponential families_ are important families of probability distributions. All exponential families are identified by a real-valued scalar function called a _log-normalizer_ (among other things). The gradient of the log-normalizer is also very important. In particular, it shows up when optimizing a predictive distribution. One example of a gradient log-normalizer is the logistic sigmoid function that shows up when optimizing a Bernoulli model.
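For instance, for a Bernoulli model with natural parameter θ, the log-normalizer is log(1 + exp(θ)) and its gradient is the logistic sigmoid. A tiny sketch of that (just for illustration):

import jax.numpy as jnp
from jax import grad
from jax.nn import sigmoid

g = lambda theta: jnp.logaddexp(0.0, theta)  # Bernoulli log-normalizer: log(1 + exp(theta))
theta = 0.3
print(grad(g)(theta), sigmoid(theta))  # both ≈ 0.5744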
The problem is that some exponential families like the complex normal distribution have a complex-valued gradient log-normalizer. This causes a problem with custom_jvp:
from jax import grad, custom_jvp
import jax.numpy as jnp

def f(q):
    return jnp.real(jnp.sum(q))

def nat_to_exp(q):
    return q

if True:  # When False, no errors are produced.
    f = custom_jvp(f)

    @f.defjvp
    def f_jvp(primals, tangents):
        q, = primals
        q_prime, = tangents
        a = nat_to_exp(q)
        return f(q), a * q_prime

some_q = jnp.array([2+1j, 1+1j], dtype=jnp.complex64)
print(grad(f)(some_q))
gives
TypeError: mul requires arguments to have the same dtypes, got float32, complex64.
What is the best way to fix this? One way would be to make the log-normalizer complex-valued, but this is unfortunate from both a usability standpoint (it shows up in many calculations that are real-valued, like the cross entropy and the density) and from an efficiency standpoint (complex numbers need twice the memory, etc.).
I'm wondering if it's possible to relax the check in JVP evaluation to make the multiplication support widening from real to complex?
If there's one thing I love, it's exponential families. It's no joke to say an original motivation for JAX was to do research involving exponential families (particularly this work and this work). So you've got allies here! Though TBH I've never worked with exponential families involving complex numbers, so I think I'm going to learn some things here.
EDIT: and this talk!
I think the problem is that the primal and tangent output are different types. This issue isn't specific to custom_jvp per se; it'd also be an error if you were working with Primitives.
I think that type error is real, not an implementation bug: it doesn't make sense for the output primal to be a real while the output tangent is a complex value, because that'd be like saying the tangent space of a real vector space is a complex vector space. More concretely, the output tangent vector is an answer to "what happens to the output of f if I wiggle the input a little bit in this direction", and if the final step of f is to apply np.real then I don't think the answer can be "it moves a little in the direction of this complex number."
I think the tangent output must be a real number. Perhaps just replacing a * q_prime with np.real(a * q_prime)? (That's a bit of a guess, but if we had a more realistic example we could check the maths do what we want.)
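Something like this rough sketch of the example above (a guess; I also reduce the tangent to the primal's scalar shape so the outputs match):

from jax import grad, custom_jvp
import jax.numpy as jnp

def nat_to_exp(q):
    return q  # stand-in for the real gradient-log-normalizer computation

@custom_jvp
def f(q):
    return jnp.real(jnp.sum(q))

@f.defjvp
def f_jvp(primals, tangents):
    q, = primals
    q_prime, = tangents
    a = nat_to_exp(q)
    # The primal output is a real scalar, so the tangent output should be too.
    return f(q), jnp.real(jnp.sum(a * q_prime))

some_q = jnp.array([2+1j, 1+1j], dtype=jnp.complex64)
print(grad(f)(some_q))  # no dtype error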
Wow, great to see that we have some common interests!
Your point about tangents being in the same space is interesting, but I don't think we can replace the tangent output with a real value, since the tangents here represent parameters, and the parameters can be complex.
If you take a look at figure 3 here: https://github.com/NeilGirdhar/efax/blob/master/expfam.pdf (a diagram that I modified from Nielsen and Nock), I think it illustrates what's going on. The log-normalizer g(x) is the height of a "bowl", and the x-axis represents the parameter space. If the parameter space were complex, then x would be a complex plane rather than an axis. Now, the gradients on that bowl can live in a space as wide as the parameters. Essentially, we have a real-valued g(x), but x is complex, so dg(x)/dx is also complex. What do you think?
if we had a more realistic example we could check the maths do what we want.)
Okay, I just released my exponential family library for JAX here. The code that breaks is this test:
https://github.com/NeilGirdhar/efax/blob/master/efax/test/test_distributions.py#L49
which takes the gradient of the log-normalizer defined here:
https://github.com/NeilGirdhar/efax/blob/master/efax/complex_normal.py#L18
whose JVP is defined here:
https://github.com/NeilGirdhar/efax/blob/master/efax/exponential_family.py#L60
@mattjj
It's been about six months, and I looked at this a bit more carefully. I think I understand what's happening. The people who write about exponential families (e.g., Nielsen and Garcia) define the log-normalizer to be a function:
g(z): C -> R
they also define
g'(z): C -> C.
g is not holomorphic, so g' is not a complex derivative. The statisticians writing about exponential families interpret g' to mean dg/dx + i dg/dy, where z = x + iy. This makes a lot of sense in machine learning in general, since g is part of the cross entropy (a common loss function), and all we want to do is ask how to adjust z in order to minimize it.
This is very close to the design choice described in the autograd notebook, whereby g' = dg/dx - i dg/dy. Doesn't the autograd notebook disagree with your argument that "it doesn't make sense for the output primal to be a real while the output tangent is a complex value, because that'd be like saying the tangent space of a real vector space is a complex vector space"? It appears that the notebook is saying that I should be able to take the derivative of a real-valued output of a complex function. Am I reading it wrong? If I can take such a derivative, why can't I provide a custom derivative?
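Here's a small check of the two conventions with f(z) = |z|^2/2 (my own toy example), whose partials are df/dx = x and df/dy = y:

import jax.numpy as jnp
from jax import grad

f = lambda z: jnp.abs(z) ** 2 / 2.0
z = jnp.complex64(2 + 3j)
g_autograd = grad(f)(z)         # the notebook's convention, df/dx - i df/dy: 2 - 3j
g_stats = jnp.conj(g_autograd)  # the statisticians' convention, df/dx + i df/dy: 2 + 3j
print(g_autograd, g_stats)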
@mattjj Sorry to ping for this again, but I just saw on https://github.com/google/jax/issues/4996 that you're suggesting:
To minimize a real-valued loss as a function of complex-valued inputs, you'd want to take steps in the negative elementwise-conjugate direction of the gradient (the reason is explained in the docs I linked above).
This is exactly what I'm trying to do in this issue. I was wondering if it would be okay to somehow allow the definition of custom VJPs for such functions?
Hm, there must be some miscommunication. I certainly agree that if f : C -> R then grad(f) : C -> C. But then we'd still have jvp(f) : (C, C) -> (R, R) (because jvp(f) models the function (x, v) ↦ (f(x), ∂f(x) v), and ∂f(x) : C -> R). We define grad(f)(x) == vjp(f, x)[1](1.0), and vjp(f, x)[1] models the transposed linearized function, which is why it goes R -> C (thus being consistent with grad(f) : C -> C).
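Here's a quick sketch checking those types with f(z) = |z|^2/2:

import jax.numpy as jnp
from jax import grad, jvp, vjp

f = lambda z: jnp.abs(z) ** 2 / 2.0
x = jnp.complex64(1 + 1j)

primal, tangent = jvp(f, (x,), (jnp.complex64(3 + 4j),))
print(primal.dtype, tangent.dtype)  # float32 float32, i.e. jvp(f) : (C, C) -> (R, R)

out, pullback = vjp(f, x)
cotangent, = pullback(jnp.float32(1.0))
print(cotangent.dtype)              # complex64, i.e. the transposed map goes R -> C
print(cotangent == grad(f)(x))      # True: grad(f)(x) == vjp(f, x)[1](1.0)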
I need to reread this issue to page back in what's going on, but in the meantime, do you agree with the above?
Great summary!
∂f(x) : C -> R
Sorry, but I think this is where we disagree. What I want ∂f(x) to mean is the complex slope of the tangent plane at (x, f(x)). How do you visualize it? Unless I've misunderstood your notation, ∂f(x) therefore maps C ⟶ C. Consider f(x) = |x^2|/2, so ∂f(x) = x. WolframAlpha seems to define it this way.
That's why I believe jvp(f): (C, C) ⟶ (R, C).
What do you think?
Ah excellent progress, we're zeroing in on the issue.
I'm using ∂f(x) as a linear map, mapping from small perturbations to the input (hence taking values in the input tangent space) to small perturbations to the output (hence taking values in the output tangent space). Hence if f: R^n -> R^m then ∂f(x) : R^n -> R^m. Clearly the dimension of the codomain of ∂f(x) must equal the dimension of the codomain of f (i.e. both are m), so why would other parts of the type change (from R to C)?
By definition, we need f(x + v) = f(x) + ∂f(x) v + O(||v||^2), and the types only work out if ∂f(x) : C -> R here.
Or, to be really concrete about it, we can think about finite differences: ∂f(x) v must approximate f(x + v) - f(x) as we make v very small. But that implies ∂f(x) v has the same type as the codomain of f, namely R here. In code, this means that jvp(f, (x,), (v,))[1] must approximate f(x + v) - f(x) for small v, and that means jvp(f, (x,), (v,))[1] has the same type as f(x) and f(x + v).
(This is a bit tangential but: in the example f(x) = |x^2|/2, or the similar function f(x) = |x|^2/2, the Jacobian can't be represented as a single complex number because the function is not holomorphic (just as no function C -> R can be holomorphic). (I'm not sure if in your example you meant |x^2|/2 or |x|^2/2, but because the latter is simpler and more canonical I'm going to work with that instead.) To write a Jacobian matrix for f(x) = |x|^2/2, we can use the natural identification of C with R^2 to define the corresponding function g(x, y) = (x^2 + y^2)/2, so that we can write the Jacobian as the row vector [ x  y ]. But just as this Jacobian maps from R^2 to R, so too must the linear map ∂f(x) map from C to R.)
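As a quick sanity check of that identification, here's a sketch computing the Jacobian of g:

import jax.numpy as jnp
from jax import jacfwd

g = lambda xy: (xy[0] ** 2 + xy[1] ** 2) / 2.0  # f(x + iy) written as a map R^2 -> R
J = jacfwd(g)(jnp.array([1.0, 1.0]))
print(J)                                  # [1. 1.], i.e. [ x  y ] at x = y = 1
print(jnp.dot(J, jnp.array([3.0, 4.0])))  # 7.0, matching the finite-difference check below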
I recommend thinking in terms of finite differences for concreteness, so that jvp(f, (x,), (v,)) ≈ (f(x), (f(x + eps * v) - f(x))/eps) for small eps. Indeed we can check that directly:
In [1]: from jax import jvp
In [2]: eps = 1e-5
In [3]: import jax.numpy as jnp
In [4]: f = lambda x: jnp.abs(x)**2 / 2.
In [5]: x = 1+1j
In [6]: v = 3+4j
In [7]: (f(x), (f(x + eps * v) - f(x))/eps)
Out[7]: (DeviceArray(0.99999994, dtype=float32), DeviceArray(7.0154667, dtype=float32))
In [8]: jvp(f, (x,), (v,))
Out[8]: (DeviceArray(0.99999994, dtype=float32), DeviceArray(7., dtype=float32))
Ah excellent progress, we're zeroing in on the issue.
Agreed.
By definition, we need f(x + v) = f(x) + ∂f(x) v + O(||v||^2), and the types only work out if ∂f(x) : C -> R here.
I think this is where we disagree. (Maybe we're getting closer to the problem?)
First of all, v is complex, so I don't think the types do work out even if "∂f(x) : C -> R" as you suggest.
Since f : C -> R, we need to use the Taylor approximation based on directional derivatives. Let ∂f_k be the directional derivative: ∂f_k(x) = lim_{h -> 0} (f(x + hk) - f(x)) / h. Then the Taylor approximation is
f(x + v) = f(x) + ∂f_1(x) Re(v) + ∂f_i(x) Im(v) + O(||v||^2)
         = f(x) + Re(conj(∂f(x)) v) + O(||v||^2),
if you are okay with defining ∂f(x) = ∂f_1(x) + i ∂f_i(x).
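A quick numeric check of this expansion with f(z) = |z|^2/2 (for which ∂f_1(x) = Re(x) and ∂f_i(x) = Im(x), so ∂f(x) = x under this definition):

import jax.numpy as jnp

f = lambda z: jnp.abs(z) ** 2 / 2.0
x, v = jnp.complex64(1 + 1j), jnp.complex64(0.03 + 0.04j)
lhs = f(x + v) - f(x)
rhs = jnp.real(jnp.conj(x) * v)  # Re(conj(∂f(x)) v) with ∂f(x) = x
print(lhs, rhs)  # both ≈ 0.07, agreeing up to the O(||v||^2) term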
If you look at section 3.4 here, this appears to be how every reference I've seen on exponential families defines the gradient of the _log-normalizer_. The log-normalizer after all is real-valued. Its domain can be complex. And the _gradient log-normalizer_ has to map complex to complex in this case. They are using the above definition (∂f(x) = ∂f_1(x) + i∂f_i(x)).
If gradient log-normalizers don't fit into JAX's JVP architecture, then maybe they could fit into the VJP architecture. It would be weird to have to special-case that, though.
The types for the Taylor expansion I wrote work: ∂f(x) v means applying the linear map ∂f(x) : C -> R to the vector v in C to get a result in R. If you'd prefer we can write that as ∂f(x)(v), though I was using the convention that one can apply a linear map without writing the parentheses.
The log-normalizer after all is real-valued. Its domain can be complex. And the gradient log-normalizer has to map complex to complex in this case.
Yup, we agree that if f : C -> R then grad(f) : C -> C.
If gradient log-normalizers don't fit into JAX's JVP architecture, then maybe they could fit into the VJP architecture. It would be weird to have to special-case that, though.
These fit fine into JAX's JVP architecture; I think you're just forgetting about the transpose. That is, if f : C -> R, then ∂f(x) : C -> R, and ∂f(x)^T : R -> C.
Maybe it'd be useful to look at an example of a C->R primitive and how its JVP or transpose rule works? Here are some examples: abs_p, real_p, imag_p. You can also test for yourself that you can use grad on those functions just fine, e.g.:
In [1]: from jax import grad
In [2]: import jax.numpy as jnp
In [3]: grad(jnp.abs)(1+1j)
Out[3]: DeviceArray(0.70710677-0.70710677j, dtype=complex64)
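To see the transpose direction explicitly, here's a small sketch with jax.linearize and jax.linear_transpose (again using f(z) = |z|^2/2):

import jax
import jax.numpy as jnp

f = lambda z: jnp.abs(z) ** 2 / 2.0
x = jnp.complex64(1 + 1j)

primal_out, f_lin = jax.linearize(f, x)  # f_lin is the linear map ∂f(x) : C -> R
print(f_lin(jnp.complex64(3 + 4j)))      # 7.0, a real tangent

f_t = jax.linear_transpose(f_lin, x)     # the transposed map, R -> C
cotangent, = f_t(jnp.float32(1.0))
print(cotangent.dtype)                   # complex64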
Okay!! Thanks for taking the time to set me straight on this one. I repaired my JVP by just returning the real components (https://github.com/NeilGirdhar/efax/commit/1ae4b7fec6146bcda3dffd00be70d7ad97f2ae03), and that seems to pass my tests comparing it with scipy. I honestly can't say I totally understand it, but I trust you.
Heh, thanks for the trust. Sorry if I'm not expressing things clearly... I realized I'm using nonstandard notation and ideas that we haven't written down from start to finish anywhere. (I prepared some slides on AD for a Stanford course recently, and those provide some clues about the notation I'm using, but they don't actually cover complex numbers at all...)
I'll take this as a +1 to having a careful exposition on complex numbers in any future AD docs we write (like the "Autodiff Cookbook Part 2" we've never gotten around to writing...).