Training a network that uses gatherND ops raises the error “Cannot compute gradient: gradient function not found for gatherND”.
I traced the code and found that the gradient function for gatherND is not implemented.
I'd like to request a gradient function for gatherND.
Thanks!
Here is a quick implementation:
import * as tf from '@tensorflow/tfjs';
import { ENGINE } from '@tensorflow/tfjs-core/dist/engine';

function gatherND(x: tf.Tensor, indices: tf.Tensor): tf.Tensor {
  // Gradient w.r.t. x: scatter dy back into a zero tensor shaped like x.
  const grad = (dy: tf.Tensor, saved: tf.Tensor[]) => {
    return { x: () => tf.scatterND(saved[0], dy, x.shape) };
  };
  return ENGINE.runKernel(
    (backend, save) => {
      save([indices]); // available as saved[0] in grad
      return backend.gatherND(x, indices);
    },
    { x },
    grad
  ) as tf.Tensor;
}
@li-yinan Is this still an issue? Were you able to try the solution above?
This worked for me. I'd love to see this in the library itself. gatherND is a very useful op -- I'm using it in an implementation of Q-learning, and I don't see how I could do it without this.
@rthadur This is still an issue and the solution from @marsiancba is now out of date because ENGINE.runKernel is deprecated. The grad function is available in TF for Python: https://github.com/tensorflow/tensorflow/blob/v2.3.1/tensorflow/python/ops/array_grad.py#L691-L701
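For reference, the gradient in question amounts to scattering dy back into a zero tensor shaped like x, summing wherever the same index was gathered more than once. Below is a minimal sketch of that idea on plain nested arrays (not tfjs tensors), assuming a rank-2 x and indices that each address a single element. The helper names gatherND2d and gatherND2dGrad are hypothetical and exist only for this illustration.

```typescript
// Hypothetical helpers on plain arrays, for illustration only (not the tfjs API).

// Forward pass: out[i] = x[indices[i][0]][indices[i][1]]
function gatherND2d(x: number[][], indices: number[][]): number[] {
  return indices.map(([r, c]) => x[r][c]);
}

// Backward pass: scatter-add dy into a zero array shaped like x.
function gatherND2dGrad(
  xShape: [number, number],
  indices: number[][],
  dy: number[]
): number[][] {
  const [rows, cols] = xShape;
  const dx: number[][] = Array.from({ length: rows }, () =>
    new Array<number>(cols).fill(0)
  );
  // Repeated indices accumulate, because the element contributed to several outputs.
  indices.forEach(([r, c], i) => {
    dx[r][c] += dy[i];
  });
  return dx;
}

const x = [[1, 2], [3, 4]];
const indices = [[0, 1], [1, 0], [0, 1]]; // [0, 1] is gathered twice
const gathered = gatherND2d(x, indices); // [2, 3, 2]
const dx = gatherND2dGrad([2, 2], indices, [10, 20, 30]); // [[0, 40], [20, 0]]
```

Elements of x that were never gathered get a zero gradient, and the element gathered twice receives the sum of both upstream gradients.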
@willclarktech Engine.runKernel is not deprecated; we will actually be moving all our kernels to use Engine.runKernel instead of Engine.runKernelFunc. We do, however, have a new way to register gradients, which should make it easier to add a custom gradient.
You can see examples of them in tfjs-core/src/gradients. Here is roughly what @marsiancba's code would look like with our new API.
// Implement the gradient
import {GatherNd, GradConfig, registerGradient, scatterND, Tensor} from '@tensorflow/tfjs-core';

export const gatherNdGradConfig: GradConfig = {
  kernelName: GatherNd,
  // Save x as well: the gradient has the shape of x, not of indices.
  // The key names here must match the input names the kernel is run with.
  inputsToSave: ['x', 'indices'],
  gradFunc: (dy: Tensor, saved: Tensor[]) => {
    const [x, indices] = saved;
    return { x: () => scatterND(indices, dy, x.shape) };
  }
};

// Register the gradient
registerGradient(gatherNdGradConfig);
I'm going to reopen this, as it would be a good addition to the library and should be relatively straightforward to add. The only thing needed is tests for the gradient, added to gather_nd_test.ts. @marsiancba, if you want to make a PR with your implementation and some tests, I'd be happy to review it. Otherwise we can add it when a few free cycles pop up.
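One way such a test could be shaped is to compare the analytic gradient against a finite-difference estimate. The sketch below does that on plain arrays rather than tfjs tensors, and it is not the gather_nd_test.ts harness. It assumes indices of shape [N, 1] that select whole rows. The names gatherRows, gatherRowsGrad, weightedLoss and numericGrad are hypothetical.

```typescript
// Hypothetical plain-array sketch of a gradient check; not the tfjs test harness.

// Forward pass: indices of shape [N, 1] select whole rows of x.
function gatherRows(x: number[][], indices: number[][]): number[][] {
  return indices.map(([r]) => [...x[r]]);
}

// Analytic gradient: scatter-add each row of dy into zeros shaped like x.
function gatherRowsGrad(x: number[][], indices: number[][], dy: number[][]): number[][] {
  const dx = x.map(row => row.map(() => 0));
  indices.forEach(([r], i) => {
    dy[i].forEach((g, c) => {
      dx[r][c] += g;
    });
  });
  return dx;
}

// Scalar loss: weighted sum of the gathered values, so dLoss/dGathered equals w.
function weightedLoss(x: number[][], indices: number[][], w: number[][]): number {
  let total = 0;
  gatherRows(x, indices).forEach((row, i) => {
    row.forEach((v, c) => {
      total += v * w[i][c];
    });
  });
  return total;
}

// Central finite differences with respect to every element of x.
function numericGrad(x: number[][], indices: number[][], w: number[][], eps = 1e-3): number[][] {
  return x.map((row, r) =>
    row.map((_, c) => {
      const plus = x.map(rw => [...rw]);
      const minus = x.map(rw => [...rw]);
      plus[r][c] += eps;
      minus[r][c] -= eps;
      return (weightedLoss(plus, indices, w) - weightedLoss(minus, indices, w)) / (2 * eps);
    })
  );
}

const xs = [[1, 2], [3, 4], [5, 6]];
const rowIdx = [[2], [0], [2]]; // row 2 is gathered twice
const weights = [[1, 2], [3, 4], [5, 6]]; // plays the role of dy
const analytic = gatherRowsGrad(xs, rowIdx, weights); // [[3, 4], [0, 0], [6, 8]]
const numeric = numericGrad(xs, rowIdx, weights);

// Largest absolute disagreement between the two gradients.
let maxDiff = 0;
analytic.forEach((row, r) => {
  row.forEach((v, c) => {
    maxDiff = Math.max(maxDiff, Math.abs(v - numeric[r][c]));
  });
});
```

Including a repeated index and an ungathered row in the test data covers the two cases a scatter-based gradient is most likely to get wrong: accumulation and zero fill.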
@tafsiri Sorry, yes, I meant that in order to get @marsiancba's code working with tfjs v2.4.0 I had to use the deprecated .runKernelFunc instead of .runKernel. Thanks for reopening!
@tafsiri please add it.