Pytorch-lightning: GAN example: Only one backward() call?

Created on 5 Dec 2019 · 7 comments · Source: PyTorchLightning/pytorch-lightning

In the PyTorch GAN tutorial https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html there are two backward() calls for the discriminator. How do you ensure this with your structure, where backward() gets called after the training step?

Best,
Alain
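
For reference, here is the pattern in question, paraphrased from the DCGAN tutorial (variable names simplified): the discriminator accumulates gradients from two separate backward() calls before taking a single optimizer step.

netD.zero_grad()

# First backward pass: loss on a real batch.
errD_real = criterion(netD(real_batch).view(-1), real_labels)
errD_real.backward()

# Second backward pass: loss on a fake batch; its gradients
# accumulate on top of the ones from the first pass.
fake = netG(noise)
errD_fake = criterion(netD(fake.detach()).view(-1), fake_labels)
errD_fake.backward()

# One optimizer step on the accumulated gradients.
optimizerD.step()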

Labels: enhancement, let's do it!, question

Most helpful comment

I am adding here that the official implementation of WGAN-GP (for example) needs this feature in order to converge.
This is a very fundamental feature in GAN training.

All 7 comments

Good catch. We don't actually support this atm.

A solution here is to allow a dictionary of options for each optimizer, which would allow an arbitrary number of calls.
Something like:

def configure_optimizers(self,...):
    opt_G = {'optimizer': Adam(...), 'frequency': 2, 'lr_scheduler': LRScheduler(...)}
    opt_D = {'optimizer': Adam(...), 'frequency': 1, 'lr_scheduler': LRScheduler(...)}

    return opt_G, opt_D

Here G would be stepped twice back to back, and D once after.
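
To make the frequency semantics concrete, here is a minimal standalone sketch (illustrative only, not actual Lightning code) of how a batch index could be mapped to an optimizer:

def optimizer_idx_for_batch(batch_idx, frequencies):
    # Cycle through optimizers according to their frequencies.
    # With frequencies [2, 1] the index sequence is 0, 0, 1, 0, 0, 1, ...
    cycle = batch_idx % sum(frequencies)
    for idx, freq in enumerate(frequencies):
        if cycle < freq:
            return idx
        cycle -= freq

# [optimizer_idx_for_batch(i, [2, 1]) for i in range(6)] == [0, 0, 1, 0, 0, 1]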

@jeffling @neggert

But I'm not sure if this is a clean user experience.

@williamFalcon that API would work for us.

@alainjungo: Some workarounds, none of them ideal:

  1. Skip every other generator step (see the sketch after this list). You'll have to double your iterations.
  2. Double the learning rate on D (not algorithmically the same, but it can have a similar effect).
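
A rough sketch of workaround 1, assuming a training_step that receives an optimizer_idx and a Lightning version in which returning None skips the update for that optimizer (check your version's behavior); d_loss and g_loss are hypothetical helpers:

def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:  # discriminator: update on every batch
        return {'loss': self.d_loss(batch)}
    if batch_idx % 2 == 0:  # generator: update on every other batch
        return {'loss': self.g_loss(batch)}
    return None  # assumption: None skips this optimizer's step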

You can always return one loss from the training step that captures both losses.
In the DCGAN example, this would look like errD = errD_real + errD_fake; errD.backward().

Not sure if I'm correct here, but this seems equivalent and matches PTL paradigms.
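
As a sketch of that combined-loss approach (d_loss_real, d_loss_fake, and g_loss are hypothetical helpers): summing the two losses before the single backward() that Lightning performs yields the same gradients as two separate backward() calls, because gradients are additive.

def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:  # discriminator
        # Equivalent to calling backward() on each term separately.
        return {'loss': self.d_loss_real(batch) + self.d_loss_fake(batch)}
    return {'loss': self.g_loss(batch)}  # generator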

I have implemented an API that allows returning optimizers, lr_schedulers, and optimizer_frequencies; based on batch_idx, it then determines the current optimizer to use.

If there is agreement on this API, I'll proceed to testing, documenting, and submitting a PR.
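
As I read the proposal, the return signature would look roughly like this (a sketch assuming standard torch optimizers and generator/discriminator submodules, not the final API):

def configure_optimizers(self):
    opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
    opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
    sched_g = torch.optim.lr_scheduler.StepLR(opt_g, step_size=10)
    sched_d = torch.optim.lr_scheduler.StepLR(opt_d, step_size=10)
    # frequencies: how many consecutive batches each optimizer runs for
    return [opt_g, opt_d], [sched_g, sched_d], [2, 1]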

Another option would be to allow returning a tuple of dictionaries, as @williamFalcon suggested. That would be a minor change for me, and I am willing to do it if it is agreed upon.

I am adding here that the official implementation of WGAN-GP (for example) needs this feature in order to converge.
This is a very fundamental feature in GAN training.
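For context, a condensed plain-PyTorch sketch of the WGAN-GP training pattern being referenced, in which the critic takes n_critic update steps per generator step (generator, critic, gradient_penalty, latent_dim, lambda_gp, and the optimizers are illustrative):

import torch

n_critic = 5  # critic steps per generator step, as in the WGAN-GP paper

for batch_idx, real in enumerate(dataloader):
    # Critic update on every batch.
    z = torch.randn(real.size(0), latent_dim)
    fake = generator(z).detach()
    d_loss = (critic(fake).mean() - critic(real).mean()
              + lambda_gp * gradient_penalty(critic, real, fake))
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # Generator update once every n_critic batches.
    if (batch_idx + 1) % n_critic == 0:
        z = torch.randn(real.size(0), latent_dim)
        g_loss = -critic(generator(z)).mean()
        opt_g.zero_grad()
        g_loss.backward()
        opt_g.step()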

I like the @williamFalcon proposal, which seems clear to me.
@asafmanor, mind sending a PR or describing what API you have in mind?
cc @PyTorchLightning/core-contributors: any comments on this?

I'll implement the @williamFalcon API and send a detailed PR over the weekend 👍

