Hi! I'm currently trying to train a GAN with some regularization, and I need to perform multiple optimizer steps with different losses in a single training step.
My intuition tells me to override the optimizer_step function, but I keep running into errors and I don't know what I am doing wrong:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time

def training_step(self, batch, batch_idx, optimizer_idx):
    real_img = batch
    batch_size = real_img.shape[0]
    z = torch.randn(batch_size, self.latent_dim).to(self.device)

    # Train Discriminator
    if optimizer_idx == 0:
        fake_img = self.generator(z)
        fake_pred = self.discriminator(fake_img)
        real_pred = self.discriminator(real_img)
        d_loss = losses.d_loss(real_pred, fake_pred)
        if batch_idx % self.args.d_regularize_every == 0:
            real_img.requires_grad = True
            real_pred = self.discriminator(real_img)
            self.d_reg_loss = losses.d_reg_loss(real_pred, real_img)
        return {'loss': d_loss}

    # Train Generator
    if optimizer_idx == 1:
        fake_img = self.generator(z)
        fake_pred = self.discriminator(fake_img)
        g_loss = losses.g_loss(fake_pred)
        if batch_idx % self.args.g_regularize_every == 0:
            self.g_reg_loss = losses.g_reg_loss(fake_img, z)
        return {'loss': g_loss}
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, *args, **kwargs):
    # Step using d_loss or g_loss
    super().optimizer_step(epoch, batch_idx, optimizer, optimizer_idx, *args, **kwargs)
    if optimizer_idx == 0:
        self.discriminator.zero_grad()
        self.d_reg_loss.backward()
        super().optimizer_step(epoch, batch_idx, optimizer, optimizer_idx, *args, **kwargs)
    if optimizer_idx == 1:
        self.generator.zero_grad()
        self.g_reg_loss.backward()
        super().optimizer_step(epoch, batch_idx, optimizer, optimizer_idx, *args, **kwargs)
I have tried directly returning the regularization losses instead of the normal losses. But then at some steps I only optimize with the regularization loss instead of the logistic loss, while I would like to always optimize at least once and then optimize again at certain steps.
Is this possible with lightning? Or am I doing something wrong?
During backward, intermediate results are deleted to reduce memory usage. Here backward ends up being called twice: the loss returned from your training_step goes to the backward function in LightningModule, where the actual loss.backward() happens, and then you call it again explicitly in your optimizer_step. By that time the intermediate results have already been freed, which is why you get the error above.
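To see that second error in isolation, here is a minimal pure-PyTorch sketch (not your code, just an illustration of the freed-buffers behavior):

import torch

w = torch.randn(3, requires_grad=True)
loss = (w ** 2).sum()  # pow saves intermediate buffers for the backward pass
loss.backward()        # the first backward frees those buffers
# loss.backward()      # a second backward raises "Trying to backward through the
#                      # graph a second time ..."; passing retain_graph=True to
#                      # the FIRST call is what would allow it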
Try:
def training_step(self, batch, batch_idx, optimizer_idx):
    real_img = batch
    batch_size = real_img.shape[0]
    z = torch.randn(batch_size, self.latent_dim).to(self.device)

    # Train Discriminator
    if optimizer_idx == 0:
        fake_img = self.generator(z)
        fake_pred = self.discriminator(fake_img)  # use self.discriminator(fake_img.detach()) if you don't want to update generator weights here
        real_pred = self.discriminator(real_img)
        d_loss = losses.d_loss(real_pred, fake_pred)
        if batch_idx % self.args.d_regularize_every == 0:
            real_img.requires_grad = True
            real_pred = self.discriminator(real_img)
            self.d_reg_loss = losses.d_reg_loss(real_pred, real_img)
        else:
            self.d_reg_loss = None
        return {'loss': d_loss}

    # Train Generator
    if optimizer_idx == 1:
        fake_img = self.generator(z)
        fake_pred = self.discriminator(fake_img)
        g_loss = losses.g_loss(fake_pred)
        if batch_idx % self.args.g_regularize_every == 0:
            self.g_reg_loss = losses.g_reg_loss(fake_img, z)
        else:
            self.g_reg_loss = None
        return {'loss': g_loss}
def backward(self, trainer, loss, optimizer, optimizer_idx):
    if optimizer_idx == 0:
        loss.backward()
    elif optimizer_idx == 1:
        # g_reg_loss shares the generator graph with g_loss, so keep the
        # graph alive for the second backward in optimizer_step
        loss.backward(retain_graph=True)

def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, *args, **kwargs):
    # Step using d_loss or g_loss, then an extra step for the regularization loss
    if optimizer_idx == 0:
        optimizer.step()
        if self.d_reg_loss is not None:
            optimizer.zero_grad()
            self.d_reg_loss.backward()
            optimizer.step()
    elif optimizer_idx == 1:
        optimizer.step()
        if self.g_reg_loss is not None:
            optimizer.zero_grad()
            self.g_reg_loss.backward()
            optimizer.step()
Thank you for the response. Now it makes sense, although the error keeps appearing. I think it might be due to something internal in my code.
Just one doubt from seeing your code:
optimizer.zero_grad() or discriminator.zero_grad()?
If you can share a colab notebook for the same, I can try to resolve this issue. Also, if your optimizer holds the parameters of the model, then optimizer.zero_grad() and model.zero_grad() do the same thing.
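A tiny sketch of that equivalence (a hypothetical toy model, just for illustration): when the optimizer is built from the model's parameters, both calls clear the same .grad tensors:

import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters())

model(torch.randn(1, 4)).sum().backward()
opt.zero_grad()    # clears the grads of model's parameters,
                   # exactly what model.zero_grad() would have cleared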
Well, I could not find the exact point where the backward() function fails to compute the gradients, even when using torch.autograd.set_detect_anomaly(True) as suggested by the error

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

But I managed to make it work with the code snippet below (I don't know if it is achieving exactly what I want, but the code works now).
It makes sense to me, at least: when the regularization is not needed, the loss returned is the previous one, but in that case no backward or optimizer_step is performed. The idea is to pass 4 optimizers, of which there are in fact only 2, and then define the backward and optimizer_step functions as you suggested, so that the wanted loss and the wanted optimizer are used in the right iterations.
def configure_optimizers(self):
    # d_reg_ratio / g_reg_ratio are defined elsewhere in the module (omitted here)
    self.g_optim = optim.Adam(
        self.generator.parameters(),
        lr=self.args.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    self.d_optim = optim.Adam(
        self.discriminator.parameters(),
        lr=self.args.lr * d_reg_ratio,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )
    # 4 entries but only 2 real optimizers;
    # optimizer_idx: 0 = d, 1 = d reg, 2 = g, 3 = g reg
    return self.d_optim, self.d_optim, self.g_optim, self.g_optim
def on_fit_start(self):
    # Initialize Variables
    self.d_reg = False
    self.g_reg = False

def training_step(self, batch, batch_idx, optimizer_idx):
    real_img = batch
    batch_size = real_img.shape[0]
    z = torch.randn(batch_size, self.latent_dim).to(self.device)

    # Train Discriminator
    if optimizer_idx == 0:
        fake_img = self.generator(z)
        fake_pred = self.discriminator(fake_img.detach())
        real_pred = self.discriminator(real_img)
        d_loss = losses.d_loss(real_pred, fake_pred)
        return {'loss': d_loss}

    # Discriminator Regularization
    elif optimizer_idx == 1:
        if batch_idx % self.args.d_regularize_every == 0:
            self.d_reg = True
            real_img.requires_grad = True
            real_pred = self.discriminator(real_img)
            self.d_reg_loss = losses.d_reg_loss(real_pred, real_img)
        # when no regularization is due, the previously stored loss is
        # returned, but no backward/optimizer_step is performed for it
        return {'loss': self.d_reg_loss}

    # Train Generator
    elif optimizer_idx == 2:
        fake_img = self.generator(z)
        fake_pred = self.discriminator(fake_img)
        g_loss = losses.g_loss(fake_pred)
        return {'loss': g_loss}

    # Generator Regularization
    elif optimizer_idx == 3:
        if batch_idx % self.args.g_regularize_every == 0:
            self.g_reg = True
            fake_img = self.generator(z)  # fresh forward with the updated generator weights
            self.g_reg_loss = losses.g_reg_loss(fake_img, z)
        return {'loss': self.g_reg_loss}

def backward(self, trainer, loss, optimizer, optimizer_idx):
    if optimizer_idx == 0 or optimizer_idx == 2:
        super().backward(trainer, loss, optimizer, optimizer_idx)
    elif optimizer_idx == 1 and self.d_reg:
        super().backward(trainer, loss, optimizer, optimizer_idx)
    elif optimizer_idx == 3 and self.g_reg:
        super().backward(trainer, loss, optimizer, optimizer_idx)

def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, *args, **kwargs):
    if optimizer_idx == 0:
        self.d_optim.step()
    if optimizer_idx == 1 and self.d_reg:
        self.d_optim.step()
        self.d_reg = False
    if optimizer_idx == 2:
        self.g_optim.step()
    if optimizer_idx == 3 and self.g_reg:
        self.g_optim.step()
        self.g_reg = False
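For anyone reading along, here is roughly what Lightning's automatic loop does per batch with this setup. This is a simplified sketch, not Lightning's actual source (the real loop also handles closures, AMP, hooks, etc.):

# simplified view of the automatic loop for the 4-entry setup above
for optimizer_idx, optimizer in enumerate([d_optim, d_optim, g_optim, g_optim]):
    out = model.training_step(batch, batch_idx, optimizer_idx)
    model.backward(trainer, out['loss'], optimizer, optimizer_idx)
    model.optimizer_step(epoch, batch_idx, optimizer, optimizer_idx)

Since training_step runs again for each index, the regularization losses are computed in a fresh forward pass after the plain losses have already been stepped, which is what avoids the inplace-modification error.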
Hope that it's helpful for someone training a GAN with multiple steps in one training_step or batch.
Thanks a lot @rohitgr7!!
Nice, @paudom! But I still think there is a way to make it work with just two optimizers. If you could share a colab notebook for the same, it would be really helpful :)
The code I'm using belongs to a private corporation which I'm part of, so I cannot provide the exact code.
But I can work on a test version in colab so we can discuss how to make it work with only 2 optimizers.
Hi again, sorry for the delay. I've made a colab with a dummy example using MNIST to train a GAN, keeping as much of the original code as possible. It shows the three options we have discussed. I still don't understand why your option does not work.
Here is the link: https://colab.research.google.com/drive/1Z2B9nd5jUJDbOFjwegZrT8SD7ZSyH9W_?usp=sharing
If for some reason it is not available, tell me so I can check that the link is correct.
@paudom I got why this is happening. Let's say your discriminator initially has weights W1. d_loss is calculated using the discriminator, so after the first d_loss.backward() and optimizer0.step() the discriminator gets updated and its weights become W2. Now, d_reg_loss was calculated in that same first step using the old weights W1 (by that time the weights had not yet been updated), but d_reg_loss.backward() runs after optimizer0.step(), so during backprop the gradients are computed against the new weights W2. That's why it's throwing the error: the forward pass happened with W1 but the backward is happening with W2. The reason it works with separate optimizers is that the discriminator weights used within optimizer_idx=1 are the already-updated ones.
I think the workflow is something like:
1. training_step (optimizer_idx=0) computes d_loss and d_reg_loss through forward passes with weights W1
2. d_loss.backward() backpropagates through the graph built with W1
3. optimizer0.step() updates the discriminator in place, W1 -> W2
4. d_reg_loss.backward() needs the W1 values saved in its graph, but they have been modified in place, hence the error
Similarly for the second one (the generator)...
So I guess using 4 optimizers which are actually 2 is the right way here :)
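A minimal way to reproduce that mechanism outside Lightning (a hypothetical toy discriminator, not the code from this thread; two layers so that the second layer's weights get saved in the graph for backprop to the first):

import torch

disc = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 1))
opt = torch.optim.SGD(disc.parameters(), lr=0.1)
real = torch.randn(8, 4)

d_loss = disc(real).mean()             # graph saved with weights W1
d_reg_loss = disc(real).pow(2).mean()  # separate forward, also saved with W1

d_loss.backward()
opt.step()                             # in-place update: W1 -> W2

d_reg_loss.backward()                  # RuntimeError: one of the variables needed
                                       # for gradient computation has been modified
                                       # by an inplace operation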
Hi @rohitgr7, thanks a lot for the explanation. I thought that when calling super().optimizer_step() the weights were updated everywhere so that this could not happen, but it makes sense that the backward uses the weights saved during the forward pass.
For the moment I'll close the issue, although it would be interesting to find a way to do this in lightning defining only 2 optimizers.
Thanks a lot for your suggestions!
Since my question is related, I'm gonna hijack the thread.
I have a model with n heads and n+1 optimizers (one for each head and one for the shared backbone). I basically want to do one pass through the model, calculate the gradients for the n+1 different parts of the model, and then update them together (if I don't do it together, pytorch complains that the weights change between consecutive backward calls).
This is approximately what it looks like. It works, I am just not sure whether pytorch lightning does some magic under the hood that I am invalidating. Would be nice if someone with a bit more knowledge could have a look.
def training_step(self, batch: dict, batch_nb: int, optimizer_idx: int) -> dict:
    if optimizer_idx == 0:
        # gonna run the complete model: backbone + n heads.
        # loss is a sum of all losses
        self.saved_loss = ...
    # every optimizer_idx returns the same saved loss
    return {'loss': self.saved_loss}

def backward(self, trainer, loss, optimizer, optimizer_idx):
    # keep the graph until the last optimizer
    if optimizer_idx < len(self.optimizers) - 1:
        loss.backward(retain_graph=True)
    else:
        loss.backward()

def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None,
                   on_tpu=False, using_native_amp=False, using_lbfgs=False):
    # only act once per batch, after the last optimizer's backward
    if optimizer_idx != len(self.task_heads):
        return
    for optimizer_idx, optimizer in enumerate(self.optimizers):
        # enable the weights of the current optimizer & disable the rest
        for param in self.parameters():
            param.requires_grad = False
        for group in optimizer.param_groups:
            for param in group['params']:
                param.requires_grad = True
        optimizer.step()
        optimizer.zero_grad()
BTW, if anyone is interested, you can check out the latest update regarding optimization: you can now do manual optimization using the automatic_optimization flag on the Trainer.
https://pytorch-lightning.readthedocs.io/en/latest/optimizers.html#manual-optimization
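For the GAN discussed above, a rough sketch of what that could look like, assuming Trainer(automatic_optimization=False) and reusing the generator/discriminator/losses names from earlier in this thread; hook names and signatures may differ between Lightning versions, so check the linked docs:

def training_step(self, batch, batch_idx):
    d_opt, g_opt = self.optimizers()
    real_img = batch
    z = torch.randn(real_img.shape[0], self.latent_dim).to(self.device)

    # discriminator: logistic loss, then an extra step for the regularization
    d_loss = losses.d_loss(self.discriminator(real_img),
                           self.discriminator(self.generator(z).detach()))
    self.manual_backward(d_loss, d_opt)
    d_opt.step()
    d_opt.zero_grad()

    if batch_idx % self.args.d_regularize_every == 0:
        real_img.requires_grad = True
        d_reg_loss = losses.d_reg_loss(self.discriminator(real_img), real_img)
        self.manual_backward(d_reg_loss, d_opt)
        d_opt.step()
        d_opt.zero_grad()

    # generator step (fresh forward, so the graph uses the updated weights)
    g_loss = losses.g_loss(self.discriminator(self.generator(z)))
    self.manual_backward(g_loss, g_opt)
    g_opt.step()
    g_opt.zero_grad()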
Thank you! Yeah, it turns out my approach does not work; pytorch lightning, for instance, sets the gradients to zero at some other time. But even when accounting for that, the training does not seem to work (it runs, but it doesn't really improve). The new API should solve the problem, however. Really exciting that version 1.0 is out now :)