I have a loss module that is loaded as part of my lightning module with its own inner network. (output is passed through the network and the result is used to compute the loss)
The problem is that before running the training steps, Lightning automatically switches the entire module to .train() mode, which recursively switches the loss's inner network as well. Since the inner network has batch normalization layers, the loss criterion effectively changes during training, which is of course not desired.
Is there a right way to incorporate such a loss function?
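To make the behaviour concrete, here is a minimal sketch in plain PyTorch (no Lightning involved). TinyLossNet and Wrapper are hypothetical stand-ins for the loss network and for the module that owns it; the only assumption is that the outer module gets .train() called on it before the training steps, as described above.

```python
import torch
from torch import nn


class TinyLossNet(nn.Module):
    """Hypothetical stand-in for a pretrained loss network with batch norm."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3)
        self.bn = nn.BatchNorm2d(4)

    def forward(self, x):
        return self.bn(self.conv(x))


class Wrapper(nn.Module):
    """Hypothetical stand-in for the module that owns the loss network."""

    def __init__(self):
        super().__init__()
        # Intended to stay in eval mode for the whole run.
        self.loss_net = TinyLossNet().eval()


model = Wrapper()
model.train()  # recursively switches every registered child to train mode
mode_after_train = model.loss_net.bn.training

stats_before = model.loss_net.bn.running_mean.clone()
with torch.no_grad():
    model.loss_net(torch.randn(8, 3, 16, 16))
# In train mode, a forward pass updates the batch-norm running statistics,
# so the loss network no longer computes the same function as before.
stats_changed = not torch.equal(stats_before, model.loss_net.bn.running_mean)
```

Here mode_after_train ends up True even though the loss network was explicitly put in eval mode, and stats_changed shows that a single forward pass already altered the running statistics.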
Hi! Thanks for your contribution! Great first issue!
Can you give me an example of what this looks like? Perhaps in a colab notebook?
Sure, the full code is a bit involved, but here is a basic example
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

class FeatureLoss(nn.Module):
    def __init__(self, ckpt_path):
        super(FeatureLoss, self).__init__()
        self.facenet = LossArch()
        self.facenet.load_state_dict(torch.load(ckpt_path))

    def forward(self, y_hat, y):
        y_hat_feats = self.facenet(y_hat)
        y_feats = self.facenet(y)
        return F.mse_loss(y_hat_feats, y_feats)

class CoreNet(pl.LightningModule):
    def __init__(self, hparams):
        super(CoreNet, self).__init__()
        # Init Network
        self.net = MyArch()
        # Load loss
        self.feat_loss = FeatureLoss('path_to_weights.ckpt')

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)  # .net is in .train() mode
        loss = self.feat_loss(y_hat, y)  # loss is also in .train() mode
        return loss
Two options on top of my mind:
1. Keep the loss module from being registered as a submodule, so that .train() will have no effect on it. The easiest way to do this is to wrap it in a list: self.feat_loss = [FeatureLoss('path_to_weights.ckpt')]
2. Override the .train method of FeatureLoss so it does nothing. When CoreNet.train() is called it recursively calls the .train() method of all its child modules, so if FeatureLoss.train() does nothing then its mode would not change.
Nice ideas. What I'm currently doing is manually switching self.feat_loss to .eval() in each training_step, so these would probably be more elegant.
I think that 1. is probably a better way to go because it also stops lightning from saving the weights of the loss in each checkpoint which also bothered me.
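The two options discussed here can be sketched side by side in plain PyTorch. This is only an illustration under assumptions: LossNet, AlwaysEvalLossNet and Core are hypothetical names, and a small batch-norm module stands in for the real loss network.

```python
import torch
from torch import nn


class LossNet(nn.Module):
    """Hypothetical stand-in for the loss's inner network."""

    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm1d(4)

    def forward(self, x):
        return self.bn(x)


class AlwaysEvalLossNet(LossNet):
    """Option 2: ignore requests to enter train mode."""

    def train(self, mode: bool = True):
        return super().train(False)


class Core(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 4)
        # Option 1: a plain Python list is not registered as a submodule.
        self.hidden_loss = [LossNet().eval()]
        # Option 2: registered as usual, but .train() is a no-op.
        self.pinned_loss = AlwaysEvalLossNet().eval()


core = Core()
core.train()

modes = {
    "net": core.net.training,
    "hidden_loss": core.hidden_loss[0].training,
    "pinned_loss": core.pinned_loss.training,
}
# Top-level names that would end up in a saved state dict.
saved_prefixes = sorted({key.split(".")[0] for key in core.state_dict()})
```

Both loss networks stay in eval mode after core.train(), but only the list-wrapped one is absent from state_dict(). The same mechanism means that calls such as .to(device) on the outer module do not reach the list-wrapped network, so it would have to be moved to the right device manually.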
However, it still feels a bit hacky and I think that especially for new practitioners it would be hard to figure out that lightning might change their loss modules to .train() mode, as the LightningModule __init__ still seems like the place the loss initialization needs to go in lightning code structure.
I agree that it may feel a bit hacky, but I would say that it has less to do with the Lightning framework and more to do with native PyTorch. You would have the same problem whether your CoreNet is a LightningModule or an nn.Module, since calling .train() on the outermost module alters all submodules.
I don't totally agree. In native PyTorch I will never put the loss module as part of my network. It would be a different object that exists in the main training loop context and therefore I would be able to pass the network from .train to .eval without affecting it.
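For comparison, this is roughly what the native PyTorch arrangement described here looks like. It is a sketch under assumptions: the two small nn.Sequential models are hypothetical stand-ins for the trainable network and the pretrained loss network, and random tensors replace a real data loader.

```python
import torch
from torch import nn
import torch.nn.functional as F

# Trainable network (hypothetical stand-in).
net = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

# Loss network lives outside `net`, so net.train() never touches it.
loss_net = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4)).eval()
for p in loss_net.parameters():
    p.requires_grad_(False)

optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

net.train()  # only the trainable network changes mode
for _ in range(3):
    x, y = torch.randn(8, 4), torch.randn(8, 4)
    # Gradients still flow through loss_net back into net.
    loss = F.mse_loss(loss_net(net(x)), loss_net(y))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

loss_bn_updates = int(loss_net[1].num_batches_tracked)
```

Because the loss network is a separate object in the training loop's scope, its batch-norm layer never sees a train-mode forward pass, and loss_bn_updates stays at zero.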
Unless I'm missing something in Lightning's structure, I don't have a different location to place the loss module as it has to be accessed in the training_step, and this is what's problematic here. So the problem here, in my opinion, is caused by the fact that the trainable network is the same object that defines the training loop and loss.
Fair point. You are not missing something in Lightning's structure, and I think your problem is very application specific. I am not familiar with this approach where you want to keep some of the trainable network in eval mode during training.
This maybe hints that you could just subclass your normalization layers such that they do not change when .train() is called.
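A minimal sketch of that suggestion, assuming the loss network can be built with a custom normalization class. EvalOnlyBatchNorm2d is a hypothetical name, not something provided by PyTorch or Lightning.

```python
import torch
from torch import nn


class EvalOnlyBatchNorm2d(nn.BatchNorm2d):
    """Hypothetical batch norm that stays in eval mode on any .train() call."""

    def train(self, mode: bool = True):
        # Ignore the requested mode and always switch to eval.
        return super().train(False)


loss_net = nn.Sequential(nn.Conv2d(3, 4, 3), EvalOnlyBatchNorm2d(4), nn.ReLU())
loss_net.train()  # the conv layer switches mode, the norm layer does not

stats_before = loss_net[1].running_mean.clone()
with torch.no_grad():
    loss_net(torch.randn(2, 3, 8, 8))
stats_unchanged = torch.equal(stats_before, loss_net[1].running_mean)
```

One caveat with this approach: a freshly constructed layer still reports training=True until .train() or .eval() has been called on it (or on a parent) at least once, since the constructor does not go through the overridden method.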
Basically it might happen whenever a loss function is based on a trainable network, for example when using facial identity loss or perceptual loss, which has become quite common for tasks involving image generation, similar to what's done in this paper
I think that the proposed solutions are helpful; it's just that I believe it's a common enough scenario to have some form of support. For example, by defining somewhere to initialize these objects that do not need to be trained.
From my experience, the fact that Lightning handles train/eval mode for you (which is great) made me miss the problem with integrating the aforementioned loss function, and it took me a while to track it down, which is probably not the experience we want users to have.
Even pointing this behavior out in the documentation could be helpful in my opinion.
Alright, I better understand the setting now. I think the compromise for now is to see if more people are interested in such a feature and then take it from there.
Our documentation can always be made clearer or extended (it is implicitly shown in the figure on this page https://pytorch-lightning.readthedocs.io/en/latest/trainer.html that model.train() is called inside trainer) so if you have a proposed way to improve this feel free to submit a PR.
Agree with @eladrich that this is an important feature.
I use perceptual loss which depends on a network.
A simple solution:
from collections import OrderedDict

class FeaturesNet(pl.LightningModule):
    def __init__(self, weights_path):
        super().__init__()
        self.model = MobileFaceNet(512)
        self.model.load_state_dict(torch.load(weights_path))

    def setup(self, device: torch.device):
        self.freeze()

    def train(self, mode: bool = True):
        """Keep PyTorch Lightning from switching this module to train mode."""
        return super().train(False)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Keep PyTorch Lightning from saving these parameters in checkpoints."""
        destination = OrderedDict()
        destination._metadata = OrderedDict()
        return destination
@SkafteNicki I don't think this feature needs to be added to PyTorch Lightning, but the transfer learning documentation needs more detail. I was training a GAN model and used a pre-trained module for the content loss, but the documentation didn't say that the model would automatically be switched to training mode.
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!