It seems like check_grads has incorrect scaling behaviour for large input tensors.
In the code of check_jvp (which is called by check_grads), the following line essentially performs a weighted-sum reduction over args.size values:
v_out, t_out = f_jvp(args, tangent)
The tangent vector has a length that grows with the dimension of args. This would warrant scaling the tolerances by args.size, but no such scaling is performed. Instead, t_out is passed into
check_close(t_out, t_out_expected, atol=atol, rtol=rtol)
which will eventually call:
_assert_numpy_allclose(a, b, atol=atol * a.size, rtol=rtol * b.size)
Thus, the tolerances end up scaled by t_out.size, the height of the Jacobian, rather than by args.size, the width along which the reduction actually happened. Moreover, _assert_numpy_allclose already performs an element-wise comparison with the given tolerances and then does
cond = reduced.all()
which is just a reduction on a boolean array, so there is no need for tolerance scaling in this function at all. If tolerance scaling is necessary, it should happen earlier, not here.
This behaviour looks incorrect to me. Am I missing something?
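To make the concern concrete, here is a standalone numpy sketch (not the library's code; the sizes and the Jacobian are made up) of the two sizes involved:

```python
import numpy as np

# Hypothetical function f: R^n -> R^m with a wide Jacobian (many inputs, few outputs).
n, m = 10_000, 3
rng = np.random.default_rng(0)
J = rng.normal(size=(m, n))        # stands in for the true Jacobian of f
tangent = rng.normal(size=n)

# The JVP t_out = J @ tangent is a reduction over n = args.size terms,
# so its numerical error grows with n ...
t_out = J @ tangent                # shape (m,)

# ... but the quoted _assert_numpy_allclose call scales the tolerances by the
# *output* size, which stays at m no matter how large n gets:
atol = 1e-6
print(atol * t_out.size)           # ~3e-06, scales with m, not n
```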
Your analysis sounds right to me. By not scaling tolerances this way, check_grads is being too strict for functions with large input dimension; is that right?
What fix do you recommend?
I forgot to ask the main question I was wondering about: is the scaling by the height of the Jacobian in _assert_numpy_allclose making us effectively too loose with some tolerances? For example, is there a test that fails with a given atol/rtol setting but starts passing if we just broadcast the output to have a taller Jacobian?
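To make that concrete, here is a standalone numpy snippet that mirrors the size-scaled comparison quoted above (the arrays and tolerances are made up; this is not the library's code):

```python
import numpy as np

atol, rtol = 1e-6, 0.0
a = np.array([1.0])
b = np.array([1.0 + 5e-6])        # off by 5e-6 > atol

def close_with_size_scaling(a, b):
    # mirrors the quoted call: tolerances get multiplied by the array size
    return np.allclose(a, b, atol=atol * a.size, rtol=rtol * b.size)

print(close_with_size_scaling(a, b))                            # False
print(close_with_size_scaling(np.tile(a, 10), np.tile(b, 10)))  # True: same per-element error, 10x looser
```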
Both of your conclusions are true:
If we want a minimal fix, we could swap where the scaling happens for check_jvp, like this:
check_close(t_out, t_out_expected, atol=atol * args.size, rtol=rtol * args.size)
_assert_numpy_allclose(a, b, atol=atol, rtol=rtol)
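Rough arithmetic for what that swap changes, with hypothetical sizes (n inputs, m outputs):

```python
n, m, atol = 10_000, 3, 1e-6
print(atol * m)   # current behaviour: effective atol scaled by the output size, ~3e-06
print(atol * n)   # proposed swap:     effective atol scaled by the input size,  ~1e-02
```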
Assume a == b is just within atol, rtol:
abs(a - b) = atol + rtol * b  ("just within")
Now take a1 = 100 * a, b1 = 100 * b:
abs(a1 - b1) = 100 * (atol + rtol * b) = 100 * atol + rtol * b1 >> atol + rtol * b1
So when the compared values grow (e.g. because a reduction sums many terms), the rtol term keeps up automatically and only atol seems to need scaling by the reduction width.
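A quick numeric spot check of that argument, with arbitrary example values and plain numpy.isclose:

```python
import numpy as np

atol, rtol = 1e-3, 1e-2
b = 1.0
a = b + 0.95 * (atol + rtol * b)     # just within the tolerance
print(np.isclose(a, b, atol=atol, rtol=rtol))               # True
# Scale both values by 100: the rtol term keeps up, the atol term does not.
print(np.isclose(100 * a, 100 * b, atol=atol, rtol=rtol))   # False
```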
Is linear scaling the right way to do it? If it is, maybe it would still be worth adding a comment in the source code so that readers can follow the reasoning?
Maybe it is also worth clarifying for users what kind of functions check_grads expects, i.e. vector/scalar input, vector/scalar output?
Finally, more a feature request than a fix: how hard would it be to add a full Jacobian checking? Then people could decide if they want a fast but less reliable check (jvp) or a slow but more reliable one (full jacobian).
scipy.optimize.check_grad does this (although it may only handle scalar outputs). In my experience it is painfully slow for all but the simplest cases. It could be nice for debugging / checking particular edge cases, but for general purposes I am happier checking the gradient in random directions.
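For what it's worth, here is a rough sketch of what a full-Jacobian check could look like (assuming JAX, given the custom_vjp mention), comparing jax.jacfwd against a simple central-difference estimate; the function, eps, and tolerances are hypothetical placeholders, not an existing API:

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)   # double precision for the finite differences

def f(x):                                   # example function R^3 -> R^2
    return jnp.stack([jnp.sum(x ** 2), jnp.prod(x)])

def numerical_jacobian(f, x, eps=1e-4):
    """Central-difference Jacobian estimate, one input dimension at a time."""
    x = np.asarray(x, dtype=np.float64)
    cols = []
    for i in range(x.size):
        d = np.zeros_like(x)
        d[i] = eps
        cols.append((np.asarray(f(x + d)) - np.asarray(f(x - d))) / (2 * eps))
    return np.stack(cols, axis=-1)

x = jnp.array([1.0, 2.0, 3.0])
J_ad = jax.jacfwd(f)(x)                     # full Jacobian via forward-mode AD
J_fd = numerical_jacobian(f, x)
np.testing.assert_allclose(np.asarray(J_ad), J_fd, atol=1e-3, rtol=1e-3)
```

The obvious downside is the loop over input dimensions, which is exactly why it gets slow for anything but small inputs, as noted above.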
I thought that the custom_vjp decoration wouldn't affect the primal calculation when the forward pass does nothing except call the base function.
That shouldn't be too hard!
I'm going to try out the fix @nsavinov suggests and report back.
I didn't follow up on this, and forgot about it until now! Un-assigning myself, though this is still worth fixing.