Add an on_after_backward callback option.
Currently there is no callback hook that runs after the backward step, so non-essential code such as gradient logging clutters the LightningModule.
Expand the callback options to include an on_after_backward() hook that executes callbacks immediately after the backward pass.
The current alternative is putting my non-essential gradient logging code in my LightningModule.
useful!
We need nearly the same feature: on_before_backward as an abstract method in the Callback base class would be awesome!
We currently use the same workaround (in training_step), but it couples the last remaining piece of our custom callback to the LightningModule code.
Maybe we should add all the optimizer-related hooks to callbacks too, if that makes sense.
What do you think @awaelchli?
@rohitgr7 which hooks do you have in mind?
@FelixLorenz it shouldn't be abstract
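A small sketch of why a concrete no-op hook is friendlier than an abstract one, using hypothetical class names: with `abc.abstractmethod`, every callback subclass must implement `on_before_backward` even if it never touches the backward pass, otherwise Python refuses to instantiate it. A default no-op lets subclasses override only the hooks they care about.

```python
import abc


class AbstractStyleCallback(abc.ABC):
    """Hypothetical variant where the hook is declared abstract."""

    @abc.abstractmethod
    def on_before_backward(self, trainer, loss):
        """Every subclass would be forced to implement this."""


class DefaultStyleCallback:
    """Hypothetical variant where the hook is a concrete no-op."""

    def on_before_backward(self, trainer, loss):
        """Does nothing unless a subclass overrides it."""


class EpochOnlyA(AbstractStyleCallback):
    """A callback with no interest in the backward pass (abstract base)."""


class EpochOnlyB(DefaultStyleCallback):
    """The same callback built on the no-op base class."""


def can_instantiate(cls: type) -> bool:
    # Instantiating a class with unimplemented abstract methods raises TypeError.
    try:
        cls()
    except TypeError:
        return False
    return True
```

`can_instantiate(EpochOnlyA)` is `False` while `can_instantiate(EpochOnlyB)` is `True`, which is the practical cost of making the hook abstract.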
@awaelchli All of them? on_after_backward, on_before_backward, on_before_zero_grad.
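A toy sketch of where the three hook names from this comment could fire inside one optimization step. The ordering shown (hook before `zero_grad`, hook before backward with the loss, hook after backward before the optimizer step) and the argument lists are assumptions for illustration, not a description of how Lightning implements them. `ToyTrainer` and `HookTracer` are hypothetical; the loss is L(p) = p² minimized by plain gradient descent.

```python
from typing import List


class Callback:
    """Proposed hooks with no-op defaults; argument lists are assumptions."""

    def on_before_zero_grad(self, trainer: "ToyTrainer") -> None:
        pass

    def on_before_backward(self, trainer: "ToyTrainer", loss: float) -> None:
        pass

    def on_after_backward(self, trainer: "ToyTrainer") -> None:
        pass


class ToyTrainer:
    """Minimizes L(p) = p**2 by gradient descent, firing hooks around backward."""

    def __init__(self, callbacks: List[Callback], lr: float = 0.25) -> None:
        self.callbacks = callbacks
        self.lr = lr
        self.param = 1.0
        self.grad = 0.0

    def _call_hook(self, name: str, *args) -> None:
        # Dispatch one hook to every callback, in registration order.
        for cb in self.callbacks:
            getattr(cb, name)(self, *args)

    def optimization_step(self) -> None:
        loss = self.param ** 2
        self._call_hook("on_before_zero_grad")
        self.grad = 0.0                      # zero_grad()
        self._call_hook("on_before_backward", loss)
        self.grad += 2.0 * self.param        # backward(): dL/dp = 2p
        self._call_hook("on_after_backward")
        self.param -= self.lr * self.grad    # optimizer.step()


class HookTracer(Callback):
    """Records hook names plus the loss and gradient each hook sees."""

    def __init__(self) -> None:
        self.calls: List[str] = []
        self.losses: List[float] = []
        self.grads: List[float] = []

    def on_before_zero_grad(self, trainer: ToyTrainer) -> None:
        self.calls.append("on_before_zero_grad")

    def on_before_backward(self, trainer: ToyTrainer, loss: float) -> None:
        self.calls.append("on_before_backward")
        self.losses.append(loss)

    def on_after_backward(self, trainer: ToyTrainer) -> None:
        self.calls.append("on_after_backward")
        self.grads.append(trainer.grad)
```

After one `optimization_step()` with a `HookTracer` registered, `tracer.calls` lists the three hooks in the order above, and the gradient recorded by `on_after_backward` is the freshly computed one, which is exactly what a gradient-logging callback needs.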
I think that would be nice to have.
Okay, I will add them.
Waiting for more approvals to see if these hooks are good to add or not.
@PyTorchLightning/core-contributors