Jax: Scaling behaviour of jax.test_util.check_grads

Created on 18 May 2020 · 7 comments · Source: google/jax

It seems like check_grads has incorrect scaling behaviour for large input tensors.

In the code of check_jvp (which is called by check_grads), the following line essentially performs a weighted-sum reduction over args.size values:

v_out, t_out = f_jvp(args, tangent)

The tangent vector has a length that grows with the dimension of args, which would warrant scaling the tolerances by args.size. However, no such scaling is performed. Instead, t_out is passed into

check_close(t_out, t_out_expected, atol=atol, rtol=rtol)

which will eventually call:

_assert_numpy_allclose(a, b, atol=atol * a.size, rtol=rtol * b.size)

Thus, the scaling happens by t_out.size, the height of the Jacobian, instead of by args.size, the width along which the reduction actually happened. Moreover, _assert_numpy_allclose performs an element-wise comparison with the given tolerances and does

cond = reduced.all()

a reduction on a boolean array, so there is no need for tolerance scaling in this function at all. If tolerance scaling is necessary, it should happen earlier, not here.
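To make the above concrete, here is a rough paraphrase of what the comparison ends up computing (my own restatement, not the actual jax source; the function name is hypothetical):

```python
import numpy as np

# Rough paraphrase of the behaviour described above (not the jax source itself):
# the tolerances get inflated by the *output* size and are then used in a plain
# element-wise comparison; the only reduction is the final .all() over booleans.
def effective_check(t_out, t_out_expected, atol, rtol):
    a, b = np.asarray(t_out), np.asarray(t_out_expected)
    scaled_atol = atol * a.size   # scales with t_out.size, the Jacobian height
    scaled_rtol = rtol * b.size
    reduced = np.abs(a - b) <= scaled_atol + scaled_rtol * np.abs(b)
    return reduced.all()          # nothing here accounts for the args.size reduction
```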

This behaviour looks incorrect to me. Am I missing something?

Labels: available, bug

All 7 comments

Your analysis sounds right to me. By not scaling tolerances this way, check_grads is being too strict for functions with large input dimension; is that right?

What fix do you recommend?

I forgot to ask the main question I was wondering about: is the scaling by the height of the Jacobian in _assert_numpy_allclose causing us to be effectively too loose with some tolerances? For example, is there a test that fails with some given atol/rtol setting, but starts passing if we just broadcast the output to have a taller Jacobian?

Both of your conclusions are true:

  1. For very wide Jacobians, the test is incorrectly strict and will fail for no good reason. This is the case of large input tensors and scalar output. This happened to me, and that is what got me digging. :)
  2. For very tall Jacobians, the test is incorrectly loose. This is the case of the artificial output broadcasting that you mentioned (see the sketch right after this list).
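To make the two cases concrete, a hypothetical sketch (the functions, sizes, and mode choice are my own; whether a particular run fails depends on the actual tolerances):

```python
import jax.numpy as jnp
from jax.test_util import check_grads

x = jnp.linspace(0.0, 1.0, 10_000)

# Case 1: very wide Jacobian (large input, scalar output). The JVP reduces
# over x.size values, but the tolerance is only scaled by the output size,
# so the check can end up spuriously strict.
wide = lambda x: jnp.sum(jnp.sin(x))

# Case 2: very tall Jacobian. Broadcasting the same scalar to many outputs
# inflates t_out.size, and with it the tolerance, so the check gets looser
# without any new information being tested.
tall = lambda x: jnp.broadcast_to(jnp.sum(jnp.sin(x)), (10_000,))

check_grads(wide, (x,), order=1, modes=("fwd",))  # may fail spuriously
check_grads(tall, (x,), order=1, modes=("fwd",))  # may pass where `wide` fails
```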

If we want a minimal fix, we could swap where the scaling happens in check_jvp, i.e. scale at the check_close call site and drop it from the _assert_numpy_allclose call:

check_close(t_out, t_out_expected, atol=atol * args.size, rtol=rtol * args.size)
_assert_numpy_allclose(a, b, atol=atol, rtol=rtol)
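That way the scaling lives in one place, check_jvp, where the reduction width args.size is actually known, and _assert_numpy_allclose stays a plain element-wise comparison.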



On the scaling itself: assume a == b is "just within" atol, rtol:
abs(a - b) = atol + rtol * b
Now take a1 = 100 * a, b1 = 100 * b (think of a sum of 100 such "just within" values):
abs(a1 - b1) = 100 * atol + rtol * 100 * b = 100 * atol + rtol * b1 >> atol + rtol * b1
So the atol part of the tolerance has to grow with the reduction width, while the rtol part already grows through b1.
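A made-up numerical version of the same effect (the sizes and tolerances are chosen purely for illustration):

```python
import numpy as np

# 100 values, each *just* within tolerance of its reference value.
atol, rtol = 1e-6, 1e-6
b = np.full(100, 1.0)
a = b + 0.9 * (atol + rtol * np.abs(b))

np.testing.assert_allclose(a, b, atol=atol, rtol=rtol)       # passes element-wise
print(np.isclose(a.sum(), b.sum(), atol=atol, rtol=rtol))    # False: the summed error
# is ~100 * atol + rtol * |b.sum()|, well beyond the unscaled atol + rtol * |b.sum()|.
```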
  1. Is linear scaling the right way to do it? If it is, it is maybe still worth a comment in the source code so that readers can follow the reasoning.

  2. It may also be worth clarifying for users what kinds of functions check_grads expects, i.e. vector/scalar input, vector/scalar output.

Finally, more a feature request than a fix: how hard would it be to add full Jacobian checking? Then people could decide whether they want a fast but less reliable check (jvp) or a slow but more reliable one (full Jacobian).

> Finally, more a feature request than a fix: how hard would it be to add full Jacobian checking? Then people could decide whether they want a fast but less reliable check (jvp) or a slow but more reliable one (full Jacobian).

scipy.optimize.check_grad does this (although it may only handle scalar outputs). In my experience it is painfully slow for all but the simplest cases. It could be nice for debugging or checking particular edge cases, but for general purposes I am happier checking the gradient in random directions.
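For reference, a rough sketch of what such a full-Jacobian check could look like. The helper, its defaults, and the usage example are all hypothetical, not part of jax; it compares jax.jacfwd against central differences:

```python
import jax
import jax.numpy as jnp
import numpy as np

def check_jacobian(f, x, eps=1e-3, atol=1e-2, rtol=1e-2):
    """Hypothetical helper: compare the AD Jacobian against central differences.

    Tolerances are loose to accommodate float32 finite differences.
    """
    jac_ad = np.asarray(jax.jacfwd(f)(x))          # shape: out_shape + x_shape
    x_flat = np.asarray(x, dtype=float).ravel()
    cols = []
    for i in range(x_flat.size):                   # one column per input entry
        dx = np.zeros_like(x_flat)
        dx[i] = eps
        xp = (x_flat + dx).reshape(np.shape(x))
        xm = (x_flat - dx).reshape(np.shape(x))
        cols.append((np.asarray(f(xp)) - np.asarray(f(xm))) / (2 * eps))
    jac_fd = np.stack(cols, axis=-1).reshape(jac_ad.shape)
    np.testing.assert_allclose(jac_ad, jac_fd, atol=atol, rtol=rtol)

# Example usage (shapes and function are arbitrary):
check_jacobian(lambda x: jnp.sin(x) * x.sum(), jnp.arange(6.0).reshape(2, 3))
```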

I thought that custom_vjp decoration wouldn't affect primal calculation when the forward pass is not doing anything except calling the base function.

That shouldn't be too hard!

I'm going to try out the fix @nsavinov suggests and report back.

I didn't follow up on this, and forgot about it until now! Un-assigning myself, though this is still worth fixing.
