Consider the following code:
from jax import grad
import jax.numpy as np
x = np.zeros(shape=(4,4))
d = grad(np.linalg.det)
print(d(x))
The derivative of np.linalg.det(x) with respect to x should clearly be a zero matrix (each entry of the gradient is a cofactor of x, i.e. the determinant of a 3x3 submatrix of the zero matrix, which is 0), but a matrix of NaNs is output instead.
I guess my replies on the equivalent Autograd issue are relevant here too.
This doesn't raise an exception based on values (most JAX APIs cannot). Instead, I see a matrix of all NaNs:
[[nan nan nan nan]
[nan nan nan nan]
[nan nan nan nan]
[nan nan nan nan]]
Yes, you're right. Sorry, my mistake: I was testing autograd as if it were jax.
Should we just add a np.where / lax.select where the predicate is whether the determinant is zero?
Sure, but we need to work out what formula to use to compute the derivative in the case det(x) == 0, because IIUC the derivative is not necessarily zero in that case.
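To illustrate the kind of masking being suggested (a sketch only; nonsingular_grad and singular_grad below are hypothetical placeholders for the two branches, and the singular-case formula is exactly what still needs working out):
from jax import lax
import jax.numpy as np

def select_det_grad(x, nonsingular_grad, singular_grad):
  # Elementwise choice between two precomputed gradient branches, based on det(x).
  is_singular = np.linalg.det(x) == 0.0
  pred = np.broadcast_to(is_singular, nonsingular_grad.shape)
  return lax.select(pred, singular_grad, nonsingular_grad)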
The derivative of the determinant is the cofactor matrix (the transpose of the adjugate), which, as someone on SO points out, can be computed using the SVD, even in the case det(x) == 0. Presumably we want to avoid doing that computation unless det(x) == 0, not sure if we're able to do that automatically though (because it sounds like value-dependent control flow).
We could have a kwarg on the det function which enables correct derivatives for det(x) == 0 at the cost of computing the svd of x.
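For concreteness, roughly what the SVD route looks like (a sketch of my own for a single real square matrix, not an actual implementation in jax): with a = u @ diag(s) @ vt, the cofactor matrix, i.e. the gradient of det, is det(u) * det(vt) * u @ diag(prod_{j != i} s_j) @ vt, which stays finite even when some s_i == 0.
import jax.numpy as np

def det_grad_via_svd(a):
  u, s, vt = np.linalg.svd(a)
  # det(u) and det(vt) are +/-1 for the orthogonal factors; their product fixes the sign.
  sign = np.linalg.det(u) * np.linalg.det(vt)
  n = a.shape[-1]
  # Leave-one-out products prod_{j != i} s_j, computed without dividing by
  # possibly-zero singular values.
  leave_one_out = np.prod(np.where(np.eye(n, dtype=bool), 1.0, s[None, :]), axis=-1)
  # cofactor(a) = d det(a) / d a
  return sign * (u * leave_one_out) @ vt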
Actually there might be an easier fix: use the LU factorization directly to compute det instead of going via slogdet. I'll have a go at implementing it.
That doesn't work, because the lu JVP also produces NaNs for singular input. I thought you might be able to fix it simply by defining
# WARNING: you also need to compute the correct sign of the determinant, which I haven't
# bothered to do here.
# (Imports as they would appear in jax/numpy/linalg.py at the time.)
import numpy as onp
from jax import lax_linalg
from jax.numpy import lax_numpy as np
from jax.numpy.lax_numpy import _wraps

@_wraps(onp.linalg.det)
def det(a):
  lu, _ = lax_linalg.lu(a)
  diag = np.diagonal(lu, axis1=-2, axis2=-1)
  return np.prod(diag, axis=-1)
instead of the current definition in terms of slogdet.
Perhaps we should aim to fix this on the lu level then, but I'm not sure whether the lu derivative is well defined for singular input (whereas I'm pretty confident the determinant derivative is well defined for singular input).
Seems like it might make sense to define some custom JVP rules?
e.g., using the determinant identities from the matrix cookbook, such as d det(X)/dX = det(X) * inv(X)^T, or equivalently d det(X) = det(X) * tr(inv(X) dX).
Direct derivative rules can often be much more efficient than differentiating the matrix factorization itself, at least they are for matrix solves.
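For example, a minimal custom_jvp sketch (my own, using the identity d det(x) = det(x) * tr(inv(x) dx); note that for singular x this is still effectively 0 * inf, so by itself it doesn't fix the NaN problem discussed here):
import jax
import jax.numpy as np

@jax.custom_jvp
def det(x):
  return np.linalg.det(x)

@det.defjvp
def det_jvp(primals, tangents):
  x, = primals
  dx, = tangents
  y = np.linalg.det(x)
  # d det(x) = det(x) * trace(x^{-1} dx); solve avoids forming the explicit inverse.
  return y, y * np.trace(np.linalg.solve(x, dx))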
The issue here is that the formula for the gradient of the determinant involves a product of the determinant and the matrix inverse. For a singular matrix, that would be basically 0*inf, which is why you get NaNs. I do have an implementation of the gradient of the determinant of a rank n-1 matrix that works directly with the LU decomposition. I don't think it works for generic low-rank matrices though.
Also, I'm sure this issue is unrelated to #2510.
The issue here is that the formula for the gradient of the determinant involves a product of the determinant and the matrix inverse. For a singular matrix, that would be basically 0*inf, which is why you get NaNs.
Good point. Just to throw out the first idea that comes to mind: would using a pseudo-inverse instead of an inverse make sense here?
Tried it, didn't work. Here's my code for the cofactor (transpose of the adjugate) that works for rank n-1 matrices:
from jax import lax
from jax import lax_linalg
from jax import ops
import jax.numpy.lax_numpy as np
import jax.numpy.linalg as linalg


def solve(a, b):
  """Compute cof(a)^T*b. Equivalent to det(a)*solve(a, b) for a nonsingular matrix.

  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.det to compute the gradient of the determinant in a way that
  is well defined even for rank n-1 matrices.
  * assumes a is at least rank n-1
  * assumes u_{nn} is the element set to 0 in singular cases.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A vector/matrix, or batch of vectors/matrices of the same dimension as a.

  Returns:
    cofactor(a)^T*b, aka adjugate(a)*b
  """
  a, b = linalg._promote_arg_dtypes(np.asarray(a), np.asarray(b))  # pylint: disable=protected-access
  a_shape = np.shape(a)
  b_shape = np.shape(b)
  a_ndims = len(a_shape)
  b_ndims = len(b_shape)
  if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2] and b_ndims >= 1):
    msg = ("The arguments to cofactor_solve must have shapes "
           "a=[..., m, m] and b=[..., m, k] or b=[..., m]; got a={} and b={}")
    raise ValueError(msg.format(a_shape, b_shape))
  if a_shape[-1] == 1:
    return b
  # lu contains u in the upper triangular matrix and l in the strict lower
  # triangular matrix. The diagonal of l is set to ones without loss of
  # generality.
  lu, pivots = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  m = a_shape[-1]
  # Numpy treats the RHS as a (batched) vector if the number of dimensions
  # differs by 1. Otherwise, broadcasting rules apply.
  x = b[..., None] if a_ndims == b_ndims + 1 else b
  batch_dims = lax.broadcast_shapes(lu.shape[:-2], x.shape[:-2])
  x = np.broadcast_to(x, batch_dims + x.shape[-2:])
  lu = np.broadcast_to(lu, batch_dims + lu.shape[-2:])
  # Compute the (partial) determinant, ignoring the last diagonal entry of LU.
  diag = np.diagonal(lu, axis1=-2, axis2=-1)
  parity = np.count_nonzero(pivots != np.arange(a_shape[-1]), axis=-1)
  sign = np.array(-2 * (parity % 2) + 1, dtype=dtype)
  # partial_det[..., -1] contains the full determinant and
  # partial_det[..., -2] contains det(a) / u_{nn}.
  partial_det = np.cumprod(diag, axis=-1) * sign[..., None]
  lu = ops.index_update(lu, ops.index[..., -1, -1], 1.0 / partial_det[..., -2])
  permutation = lax_linalg.lu_pivots_to_permutation(pivots, m)
  permutation = np.broadcast_to(permutation, batch_dims + (m,))
  iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims + (1,)))
  x = x[iotas[:-1] + (permutation, slice(None))]
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                  unit_diagonal=True)
  x = ops.index_update(x, ops.index[..., :-1, :],
                       x[..., :-1, :] * partial_det[..., -1, None, None])
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
  return x[..., 0] if a_ndims == b_ndims + 1 else x
In the case that the matrix rank is less than n-1, the gradient of the determinant will be identically zero. So we could add a lax.cond to the above function that checks if there is more than one zero on the diagonal of lu. The reason I hadn't done this yet was that I was not sure what tolerance triangular_solve uses to determine if a diagonal element is zero or not (presumably we should match that tolerance).
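Roughly what I have in mind for that check (a standalone sketch; the tolerance is an arbitrary placeholder, not whatever triangular_solve actually uses):
import jax.numpy as np

def det_grad_is_identically_zero(lu, tol=1e-12):
  # More than one (near-)zero entry on the diagonal of U means rank(a) < n-1,
  # in which case every cofactor, and hence the gradient of det, is zero.
  diag = np.diagonal(lu, axis1=-2, axis2=-1)
  return np.count_nonzero(np.abs(diag) <= tol, axis=-1) > 1
The result of solve above could then be zeroed out with an np.where on that predicate.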
Just a note RE "This function borrows heavily from jax.numpy.linalg.solve": I updated the implementation of jax.numpy.linalg.solve a month or so ago in https://github.com/google/jax/pull/2220, so it looks pretty different now. That would probably speed up gradients of your function; not sure if it would change the numerics.
Thanks, I'll take a look.
Also, looking at triangular_solve, it seems like there are different primitives for different backends. Does anyone have any idea how to safely check if the backend considers the matrix singular? Should we just do a try/catch that returns zero if the backend considers the matrix to be singular?
Though maybe you misunderstood my comment in the code. It borrows heavily from the forward computation of solve, not the gradient of solve. Did that change at all?
Ah I guess you said it might change the gradient of my solve. Since solve here computes the gradient of the determinant, that would be the gradient of the gradient in my case.
Good point, my change probably isn't relevant for your gradient rule.
I do, however, need second derivatives of the determinant, so it may be useful for me after all.
I thought the following would be a nice quick work-around, the idea being to let the jvp rule of prod handle the awkwardness of differentiating a product of many terms. But it doesn't seem to support second derivatives yet (see below).
import jax.numpy.linalg as la
import jax.numpy as np

def correct_derivatives_det(x):
  _, s, _ = la.svd(x)
  return np.prod(s)
When you try to compute a second derivative you get
NotImplementedError: Forward-mode differentiation rule for 'reduce_window' not implemented
That's an issue that's already documented here; maybe this will provide a bit of extra motivation to implement that rule...
EDIT: To be clear, I'm suggesting the above as a temporary work-around, not a permanent fix. I think the correct thing to do is probably to fix the derivative of lu and then implement det in terms of lu (without bothering with a custom det derivative), because afaict that would be straightforward and pretty fast. There are also cholesky and qr if lu turns out to be complicated to fix.
The function I shared above already works. I'm just working on integrating it into JAX.
I'd be interested to know if that approach is faster than using the derivative of the lu decomposition. In the long term we should aim for an approach that works for all matrix ranks and differentiation orders (since that shouldn't be too difficult once we have 2nd and higher order np.prod derivatives).
Just thinking also that defining det directly in terms of lu/svd/cholesky will be less numerically stable than using slogdet (just checked and Numpy computes det using slogdet). So that's another reason to maybe prefer keeping the current det implementation and adding a custom jvp.
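For reference, the slogdet-based formulation is essentially (a simplified sketch, not jax's exact code):
import jax.numpy as np

def det_via_slogdet(a):
  sign, logabsdet = np.linalg.slogdet(a)
  return sign * np.exp(logabsdet)
The running product of LU diagonal entries can overflow or underflow even when the final determinant is representable; summing their logs inside slogdet avoids that.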
I'm actually having a weirdly difficult time reproducing this bug. For anything other than an identically zero matrix, even if the determinant is still zero to within numerical precision, the gradient often still works fine. I think this is due to loss of numerical precision that, weirdly, helps us in this case. Still, I've managed to find a few cases where the existing implementation returns NaN and the new version works.
I've got a sort-of-working implementation now in my forked repo at github.com/dpfau/jax (see linalg.py and linalg_test.py). At the moment the issues are:
* I can't simultaneously define a custom_jvp and custom_vjp rule, and if I try to only use a custom_jvp rule, certain ops can't be transposed (I think scatter is the first to fail)
* Something seems to be failing in the higher-order derivatives as well
However, it sounds like the JAX team is still sorting out some other bugs involved with the new custom derivatives, so I'll wait until that is sorted out.
I think https://github.com/google/jax/pull/2597 being merged should mean that the code in https://github.com/google/jax/issues/2380#issuecomment-604359486 will work for all orders of derivative and all matrix ranks 😊
This is now fixed by PR #2809. Please close this issue.
Woohoooooo thanks @dpfau!