Pytorch-lightning: Do multiple optimizer steps in one training step

Created on 4 Aug 2020 · 13 comments · Source: PyTorchLightning/pytorch-lightning

What is your question?

Hi! I'm currently trying to train a GAN with some regularization, and I need to perform multiple optimizer steps with different losses within a single training step.

My intuition tells me to override the optimizer_step function, but I keep running into errors and I don't know what I am doing wrong:

  • RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
  • RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time

Code

def training_step(self, batch, batch_idx, optimizer_idx):
    real_img = batch
    batch_size = real_img.shape[0]
    z = torch.randn(batch_size, self.latent_dim).to(self.device)

    # Train Discriminator 
    if optimizer_idx == 0:
        fake_img = self.generator(z)
        fake_pred = self.discriminator(fake_img)
        real_pred = self.discriminator(real_img)
        d_loss = losses.d_loss(real_pred, fake_pred)
        if batch_idx % self.args.d_regularize_every == 0:
            real_img.requires_grad = True
            real_pred = self.discriminator(real_img)
            self.d_reg_loss = losses.d_reg_loss(real_pred, real_img)
        return {'loss': d_loss}

    # Train Generator
    if optimizer_idx == 1:
        fake_img = self.generator(z)
        fake_pred = self.discriminator(fake_img)
        g_loss = losses.g_loss(fake_pred)
        if batch_idx % self.args.g_regularize_every == 0:
            self.g_reg_loss = losses.g_reg_loss(fake_img, z)
        return {'loss': g_loss}

def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, *args, **kwargs):
    # Step using d_loss or g_loss
    super().optimizer_step(epoch, batch_idx, optimizer, optimizer_idx, *args, **kwargs)
    if optimizer_idx == 0:
        self.discriminator.zero_grad()
        self.d_reg_loss.backward()
        super().optimizer_step(epoch, batch_idx, optimizer, optimizer_idx, *args, **kwargs)
    if optimizer_idx == 1:
        self.generator.zero_grad()
        self.g_reg_loss.backward()
        super().optimizer_step(epoch, batch_idx, optimizer, optimizer_idx, *args, **kwargs)

What have you tried?

I have tried to directly return the regularization losses instead of the normal losses. But then, at some steps, instead of optimizing with the logistic loss I only optimize with the regularization loss, whereas I would like to always optimize at least once and then optimize a second time at certain steps.

Is this possible with Lightning? Or am I doing something wrong?

What's your environment?

  • OS: MacOS
  • Packaging: pip
  • Version: 0.8.4
question

All 13 comments

Hi! Thanks for your contribution, great first issue!

During backward, intermediate results are deleted to reduce memory usage. Here backward is effectively called twice: the loss returned from training_step goes to the backward function in LightningModule, where the actual loss.backward() happens, and then it is called again explicitly in your optimizer_step. By that point the intermediate results have already been freed, which is why you are getting the errors above.
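For reference, a minimal standalone repro of that second error in plain PyTorch (an illustrative sketch, not your code):

import torch

w = torch.randn(3, requires_grad=True)
loss = (w ** 2).sum()

loss.backward()  # the graph's intermediate buffers are freed here
loss.backward()  # RuntimeError: Trying to backward through the graph a second time
# passing retain_graph=True to the first backward() keeps the buffers alive,
# at the cost of extra memory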

Try:

def training_step(self, batch, batch_idx, optimizer_idx):
    real_img = batch
    batch_size = real_img.shape[0]
    z = torch.randn(batch_size, self.latent_dim).to(self.device)

    # Train Discriminator 
    if optimizer_idx == 0:
        fake_img = self.generator(z)
        fake_pred = self.discriminator(fake_img) # use self.discriminator(fake_img.detach()) if you don't want to update generator weights here
        real_pred = self.discriminator(real_img)
        d_loss = losses.d_loss(real_pred, fake_pred)
        if batch_idx % self.args.d_regularize_every == 0:
            real_img.requires_grad = True
            real_pred = self.discriminator(real_img)
            self.d_reg_loss = losses.d_reg_loss(real_pred, real_img)
        else:
            self.d_reg_loss = None
        return {'loss': d_loss}

    # Train Generator
    if optimizer_idx == 1:
        fake_img = self.generator(z)
        fake_pred = self.discriminator(fake_img)
        g_loss = losses.g_loss(fake_pred)
        if batch_idx % self.args.g_regularize_every == 0:
            self.g_reg_loss = losses.g_reg_loss(fake_img, z)
        else:
            self.g_reg_loss = None
        return {'loss': g_loss}

def backward(self, trainer, loss, optimizer, optimizer_idx):
    if optimizer_idx == 0:
        loss.backward()
    elif optimizer_idx == 1:
        loss.backward(retain_graph=True)

def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, *args, **kwargs):
    # Step using d_loss or g_loss
    if optimizer_idx == 0:
        optimizer.step()
        if self.d_reg_loss is not None:
            optimizer.zero_grad()
            self.d_reg_loss.backward()
            optimizer.step()
    elif optimizer_idx == 1:
        optimizer.step()
        if self.g_reg_loss is not None:
            optimizer.zero_grad()
            self.g_reg_loss.backward()
            optimizer.step()

Thank you for the response. Now it makes sense, although the error keeps appearing. I think it might be caused by something internal to my code.

Just one doubt from seeing your code,
optimizer.zero_grad() or discriminator.zero_grad()?

If you can share a Colab notebook for this, I can try to resolve the issue. Also, if your optimizer holds the parameters of the model, then optimizer.zero_grad() and model.zero_grad() do the same thing.
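A quick sketch of that equivalence (illustrative only, using a toy nn.Linear model):

import torch
from torch import nn, optim

model = nn.Linear(4, 2)
opt = optim.SGD(model.parameters(), lr=0.1)  # optimizer holds ALL of the model's parameters

model(torch.randn(8, 4)).sum().backward()    # populate .grad on every parameter

opt.zero_grad()  # clears the grads of every param group held by the optimizer
# model.zero_grad() would have done the same here, since the optimizer
# was built from model.parameters()
assert all(p.grad is None or not p.grad.any() for p in model.parameters())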

Well, I could not find the exact point where the backward() function fails to compute the gradients, even when using torch.autograd.set_detect_anomaly(True) as suggested by the error:

  • RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation.

But I managed to make it work with the code snippet below (I don't know if it achieves exactly what I want, but the code works now).

It makes sense to me, at least: when the regularization is not needed, the loss returned is the previous one, but in that case no backward or optimizer_step is performed. The idea is to pass 4 optimizers, which are in fact only 2, and then define the backward and optimizer_step functions as you suggested, so that the desired loss and optimizer are used at the right iterations.

def configure_optimizers(self):
    # d_reg_ratio / g_reg_ratio are regularization ratios computed elsewhere
    self.g_optim = optim.Adam(
        self.generator.parameters(),
        lr=self.args.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
        )
    self.d_optim = optim.Adam(
        self.discriminator.parameters(),
        lr=self.args.lr * d_reg_ratio,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
        )
    return self.d_optim, self.g_optim, self.d_optim, self.g_optim

def on_fit_start(self):
    # Initialize Variables
    self.d_reg = False
    self.g_reg = False

def training_step(self, batch, batch_idx, optimizer_idx):
    real_img = batch
    z = torch.randn(real_img.shape[0], self.latent_dim).to(self.device)
    # Train Discriminator 
    if optimizer_idx == 0:
        fake_img = self.generator(z)
        fake_pred = self.discriminator(fake_img.detach())
        real_pred = self.discriminator(real_img)
        d_loss = losses.d_loss(real_pred, fake_pred)
        return {'loss': d_loss}

    # Discriminator Regularization
    elif optimizer_idx == 1:
        if batch_idx % self.args.d_regularize_every == 0:
            self.d_reg = True
            real_img.requires_grad = True
            real_pred = self.discriminator(real_img)
            self.d_reg_loss = losses.d_reg_loss(real_pred, real_img)
            return {'loss': self.d_reg_loss}
        return {'loss': self.d_reg_loss}

    # Train Generator
    elif optimizer_idx == 2:
        fake_img = self.generator(z)
        fake_pred = self.discriminator(fake_img)
        g_loss = losses.g_loss(fake_pred)
        return {'loss': g_loss}

    # Generator Regularization
    elif optimizer_idx == 3:
        if batch_idx % self.args.g_regularize_every == 0:
            self.g_reg = True
            fake_img = self.generator(z)  # fresh forward pass with the updated weights
            self.g_reg_loss = losses.g_reg_loss(fake_img, z)
            return {'loss': self.g_reg_loss}
        return {'loss': self.g_reg_loss}

def backward(self, trainer, loss, optimizer, optimizer_idx):
    # regular losses always get a backward; regularization losses
    # only when their flag is set for this batch
    if optimizer_idx in (0, 2):
        super().backward(trainer, loss, optimizer, optimizer_idx)
    elif optimizer_idx == 1 and self.d_reg:
        super().backward(trainer, loss, optimizer, optimizer_idx)
    elif optimizer_idx == 3 and self.g_reg:
        super().backward(trainer, loss, optimizer, optimizer_idx)

def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, *args, **kwargs):
    if optimizer_idx == 0: 
        self.d_optim.step()
    if optimizer_idx == 1 and self.d_reg:
        self.d_optim.step()
        self.d_reg = False
    if optimizer_idx == 2:
        self.g_optim.step()
    if optimizer_idx == 3 and self.g_reg:
        self.g_optim.step()
        self.g_reg = False

Hope this is helpful for someone training a GAN with multiple steps in one training_step or batch.
Thanks a lot @rohitgr7!!

Nice! @paudom, but I still think there is a way to make it work with just two optimizers. If you can share a Colab notebook for this, it would be really helpful :)

The code I'm using belongs to a private company I'm part of, so I cannot share the exact code. But I can work on a test version in Colab so we can discuss how to make it work with only 2 optimizers.

Hi again, sorry for the delay. I've made a Colab notebook with a dummy example using MNIST to train a GAN, keeping as much of the original code as possible. It shows the three options we have discussed. I still don't understand why your option does not work.

Here is the link: https://colab.research.google.com/drive/1Z2B9nd5jUJDbOFjwegZrT8SD7ZSyH9W_?usp=sharing

If for some reason it is not accessible, tell me so I can check that the link is correct.

@paudom I figured out why this is happening. Say your discriminator initially has weights W1. d_loss is computed with the discriminator at W1; after d_loss.backward() and optimizer0.step(), the weights are updated in place and become W2. d_reg_loss was also computed in that same forward pass, i.e. with the old weights W1, but d_reg_loss.backward() runs after optimizer0.step(), so backprop would have to use the new weights W2. The forward pass happened with W1 while the backward happens with W2, and that is why it throws this error. The reason it works with separate optimizers is that the discriminator forward pass inside the optimizer_idx=1 branch uses the already-updated weights.
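A minimal plain-PyTorch repro of this (an illustrative sketch, not the notebook's code):

import torch
from torch import nn, optim

generator = nn.Linear(2, 4)
discriminator = nn.Linear(4, 1)
d_opt = optim.SGD(discriminator.parameters(), lr=0.1)

z = torch.randn(8, 2)
fake = generator(z)

d_loss = discriminator(fake.detach()).mean()    # forward pass with disc weights W1
d_reg_loss = (discriminator(fake) ** 2).mean()  # also computed with W1

d_loss.backward()
d_opt.step()           # in-place update: disc weights are now W2

d_reg_loss.backward()  # RuntimeError: one of the variables needed for gradient
                       # computation has been modified by an inplace operation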

I think the workflow is something like:

  1. d_loss = # calculate loss1 using discriminator
  2. d_loss.backward()
  3. optimizer1.step()
  4. optimizer1.zero_grad()
  5. d_reg_loss = # calculate using the updated discriminator from step 3
  6. d_reg_loss.backward()
  7. optimizer1.step()
  8. optimizer1.zero_grad()

Similarly for the second one...
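In plain PyTorch, the discriminator half of that workflow would look roughly like this (a sketch reusing the names from the snippets above):

d_loss = losses.d_loss(discriminator(real_img), discriminator(fake_img.detach()))
d_loss.backward()
optimizer1.step()       # discriminator weights updated in place
optimizer1.zero_grad()

if batch_idx % d_regularize_every == 0:
    real_img.requires_grad_(True)
    real_pred = discriminator(real_img)  # fresh forward pass with the NEW weights
    d_reg_loss = losses.d_reg_loss(real_pred, real_img)
    d_reg_loss.backward()
    optimizer1.step()
    optimizer1.zero_grad()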

So I guess using 4 optimizers which are actually 2 is the right way here :)

Hi @rohitgr7, thanks a lot for the explanation. I thought that when calling super().optimizer_step() the weights would be updated everywhere, so this could not happen, but it makes sense that the backward pass relies on the weights used during the forward pass.

For the moment I'll close the issue, although it would be interesting to find a way to do this in Lightning with only 2 optimizers.

Thanks a lot for your suggestions!

Since my question is related, I'm gonna hijack the thread

I have a model with n heads and n+1 optimizers (one for each head and one for the shared backbone). I basically want to do one pass through the model, calculate the gradients for the n+1 different parts of the model, and then update them all together (if I don't update them together, PyTorch complains that the weights change between consecutive backward calls).

This is approximately what it looks like. It works, I'm just not sure whether PyTorch Lightning does some magic under the hood that I am invalidating. It would be nice if someone with a bit more knowledge could have a look.

def training_step(self, batch: dict, batch_nb: int, optimizer_idx: int) -> dict:

    if optimizer_idx == 0:
        # run the complete model: backbone + n heads;
        # the loss is the sum of all the head losses
        self.saved_loss = ...

    # every optimizer re-uses the same loss
    return {'loss': self.saved_loss}

def backward(self, trainer, loss, optimizer, optimizer_idx):

    # keep the graph until the last optimizer
    if optimizer_idx < len(self.optimizers) - 1:
        loss.backward(retain_graph=True)
    else:
        loss.backward()

def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False,
                   using_native_amp=False, using_lbfgs=False):

    # only do the combined update once, after the last optimizer's backward
    if optimizer_idx != len(self.task_heads):
        return

    for optimizer_idx, optimizer in enumerate(self.optimizers):

        # enable the weights of the current optimizer & disable the rest
        for param in self.parameters():
            param.requires_grad = False
        for group in optimizer.param_groups:
            for param in group['params']:
                param.requires_grad = True

        optimizer.step()
        optimizer.zero_grad()

BTW, if anyone is interested, check out the latest update regarding optimization: you can now do manual optimization using the automatic_optimization flag in the Trainer.
https://pytorch-lightning.readthedocs.io/en/latest/optimizers.html#manual-optimization
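For reference, a rough sketch of the original two-optimizer schedule under manual optimization, following the 1.0-era docs linked above (self.optimizers(), self.manual_backward(loss, optimizer), and Trainer(automatic_optimization=False); exact signatures differ between Lightning versions, and losses / self.args are the names from the snippets earlier in the thread):

import torch
import pytorch_lightning as pl

class GAN(pl.LightningModule):
    # generator, discriminator, losses, args defined as in the snippets above

    def training_step(self, batch, batch_idx):
        d_opt, g_opt = self.optimizers()
        real_img = batch
        z = torch.randn(real_img.shape[0], self.latent_dim, device=self.device)

        # always take one discriminator step on the logistic loss
        d_loss = losses.d_loss(self.discriminator(real_img),
                               self.discriminator(self.generator(z).detach()))
        self.manual_backward(d_loss, d_opt)
        d_opt.step()
        d_opt.zero_grad()

        # every d_regularize_every batches, take an extra step on the
        # regularization loss, computed with a fresh forward pass
        if batch_idx % self.args.d_regularize_every == 0:
            real_img.requires_grad_(True)
            d_reg_loss = losses.d_reg_loss(self.discriminator(real_img), real_img)
            self.manual_backward(d_reg_loss, d_opt)
            d_opt.step()
            d_opt.zero_grad()

        # ... and the analogous two steps for the generator with g_opt

# trainer = pl.Trainer(automatic_optimization=False)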

Thank you! Yeah, it turns out my approach does not work: PyTorch Lightning, for instance, zeroes the gradients at some other point. But even accounting for that, the training seems not to work (it runs, but it doesn't really improve). The new API should solve the problem, however! Really exciting that version 1.0 is out now :)

