Pytorch-lightning: Calculating gradient during validation

Created on 5 Sep 2019 · 7 comments · Source: PyTorchLightning/pytorch-lightning

Hello,
I am using pytorch-lightning in a physics-related project where I need to calculate gradients during validation. At the moment, gradient calculation is disabled in the trainer's validate function. Commenting out that line solves the problem, but it took some time to figure out where things were going wrong.
Solution:
I would suggest adding another parameter to the Trainer (e.g. enable_grad_during_validation) to allow enabling gradient calculation during validation. This parameter should default to False so that nothing changes for users who do not need the feature. The required changes would be to add the parameter and to change the line where gradients are disabled:

   # disable gradients to save memory
   torch.set_grad_enabled(enable_grad_during_validation)
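To illustrate what the proposal would look like in use, here is a minimal sketch of the suggested behavior: gradients stay disabled by default and the flag opts back in. Note that `enable_grad_during_validation` and `run_validation_step` are hypothetical names from this suggestion, not an existing Trainer argument or Lightning API:

```python
import torch

def run_validation_step(model, batch, enable_grad_during_validation=False):
    # Sketch of the proposed behavior: gradients are disabled by default
    # (current Lightning behavior), and the hypothetical flag opts back in.
    with torch.set_grad_enabled(enable_grad_during_validation):
        return model(batch)

model = torch.nn.Linear(3, 1)
x = torch.randn(2, 3)
y_off = run_validation_step(model, x)        # default: no autograd graph is built
y_on = run_validation_step(model, x, True)   # opt-in: output carries a graph
```

With the default, `y_off.requires_grad` is False, matching today's behavior; with the flag set, `y_on.requires_grad` is True and backpropagation is possible.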

This might be a niche issue, so if no one else needs this change, I would instead suggest adding a note to the validation-loop documentation informing users that gradient calculation is disabled during validation.
P.S.: thank you for the great library!

enhancement help wanted

All 7 comments

It's standard to freeze the model during validation, so this is an edge case. However, I think adding a note to the docs would be helpful.

want to take a stab at it?

Important for meta-learning and nested optimization research:

Having this feature could be quite useful for researchers working on meta-learning and nested optimization.

For example, without the option to enable gradients during validation, the recent inner-loop optimization library _higher_ from Facebook AI is incompatible with PyTorch Lightning.

Are there any negative downstream effects of enabling gradients during validation that I might be missing? If there aren't any, then addressing this issue by just adding a new argument to the Trainer class seems reasonable to me. I'd be happy to take a stab at it if the maintainers are ok with adding this feature.

Thanks!

@cemanil Is there a reason why we can't do the meta-test in the training step itself and have the validation remain true validation without backpropagating through it?

Thank you for your response.

I might have misunderstood your recommendation. If we'd like to compute the performance of our bilevel model on the validation or test set, how can/should we do so in the training step? Models like Structured Prediction Energy Networks require running backpropagation as part of inference, which in turn requires gradient computation to be enabled.

I agree with cemanil that there are important mainstream use cases for validation-time gradient computation, for instance any inference task using MCMC sampling (e.g. for energy-based models or Bayesian inference).
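To make the energy-based-model point concrete, here is a minimal illustrative sketch (not from the thread; `gradient_inference_step` and the toy energy are made up for illustration) of why inference itself can need autograd: a single gradient step on the energy with respect to the *input*, as in Langevin/MCMC-style sampling:

```python
import torch

def gradient_inference_step(energy_fn, x, step_size=0.1):
    # One gradient-descent step on the energy w.r.t. the *input*, as used in
    # Langevin/MCMC-style inference for energy-based models. This requires
    # autograd to be enabled even though no parameters are being trained.
    x = x.clone().detach().requires_grad_(True)
    energy = energy_fn(x).sum()
    (grad,) = torch.autograd.grad(energy, x)
    return (x - step_size * grad).detach()

# Toy quadratic energy with its minimum at the origin (illustrative only).
energy = lambda z: (z ** 2).sum(dim=-1)
x0 = torch.full((1, 2), 4.0)
x1 = gradient_inference_step(energy, x0)  # moves x0 toward the minimum
```

If gradients are globally disabled around validation, `torch.autograd.grad` here raises an error, which is exactly the failure mode this issue describes.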

Why can't you just enable it again in validation_step?

Lightning handles the major use cases, but this edge case (or not-so-edge case for your research, haha) can just be handled like this:

   def validation_step(self, batch, batch_idx):
       torch.set_grad_enabled(True)
       ...
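A fuller sketch of that workaround, assuming a physics-style quantity that needs input gradients (the `PhysicsModule` class, its loss details, and the returned keys are illustrative, not from the thread or the Lightning API):

```python
import torch

class PhysicsModule:
    # Stand-in for a pl.LightningModule; only validation_step is sketched.
    def __init__(self):
        self.model = torch.nn.Linear(3, 1)

    def validation_step(self, batch, batch_idx):
        # Lightning turns gradients off around validation; re-enable them
        # locally so autograd can run inside this step.
        torch.set_grad_enabled(True)
        x = batch.requires_grad_(True)
        out = self.model(x)
        # Illustrative physics-style quantity needing input gradients.
        (dx,) = torch.autograd.grad(out.sum(), x)
        return {"val_out": out.detach(), "grad_norm": dx.norm().item()}

module = PhysicsModule()
result = module.validation_step(torch.randn(4, 3), 0)
```

One caveat with this approach: `torch.set_grad_enabled(True)` toggles a global flag, so gradients stay enabled for the rest of the validation loop unless the trainer resets it, which is the "unanticipated side effect" concern raised below.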

But either way, in a week or so we can revisit this since we're finishing refactors

Thank you for your reply William!

This is indeed how I ended up enabling test-time gradient computation. I was just a bit hesitant to manually toggle flags like this, in order to avoid any unanticipated side effects. This one does seem pretty harmless, though.

Do you think just adding a sentence or two about this in the documentation should suffice, then? Or would it be cleaner to add an argument to the Trainer class?
