Pytorch-lightning: State maintenance in DP

Created on 2 Dec 2019 · 7 comments · Source: PyTorchLightning/pytorch-lightning

In many GAN-based image generation tasks, the generator and discriminator are trained on the same generated image within a single iteration.
In PyTorch Lightning, the procedure is written like this:

def training_step(self, batch, batch_nb, optimizer_i):
    foo = batch['foo']
    bar = batch['bar']

    if optimizer_i == 0:  # train discriminator
        self.foo_out = self.netG(foo)  # register as an instance variable

        # calc d_loss
        d_loss = ...

        return {'loss': d_loss}

    elif optimizer_i == 1:  # train generator
        # common reconstruction error
        g_loss = F.l1_loss(self.foo_out, bar)
        # other losses
        ...

        return {'loss': g_loss}

It works well on a single GPU; however, self.foo_out has already been flushed by the time the optimizer_i == 1 branch runs when DP is enabled.

I think this is undesired behavior. Any help or fix?

Labels: DP, enhancement, help wanted

All 7 comments

@S-aiueo32 yeah, this is a limitation of PyTorch. I've been looking at how to maintain state when using DP but there seems to be no clear way...

@pietern I think we talked about this a few months ago. Any suggestions on how to maintain state when using DP?

DP replicates the source module for every call to forward. If you want to maintain state, you can't do this and rather should replicate once and then broadcast parameters and buffers from module[0] to the others. See torch/nn/parallel/{data_parallel,replicate}.py for more details. You'll see a section that broadcasts and sets the parameters/buffers. That's what still needs to be done for every iteration. The part that runs _replicate_for_data_parallel is what you'd want to skip.
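
To make the failure mode concrete, here is a simplified paraphrase of what nn.DataParallel.forward does on every call (not the verbatim source, but the structure matches torch/nn/parallel/data_parallel.py). The replicas are rebuilt from self.module each time, so an attribute that training_step sets on a replica, such as self.foo_out above, never reaches self.module and is gone by the next call:

# simplified paraphrase of nn.DataParallel.forward (runs on every call)
def forward(self, *inputs, **kwargs):
    # split the batch across the devices
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
    # build fresh replicas of self.module; anything a previous call stored
    # on a replica (e.g. self.foo_out) no longer exists here
    replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
    # run each replica on its chunk in parallel, then gather the outputs
    outputs = self.parallel_apply(replicas, inputs, kwargs)
    return self.gather(outputs, self.output_device)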

@williamFalcon @pietern
Thank you for the detailed explanation.
I understand the limitation now, and that it is not avoidable as long as LightningModule inherits from nn.Module.

Actually, it should be avoidable given the explanation above. We just need to make the appropriate changes to the DP subclass.

This should be a companion class to nn.DataParallel. I don't want to change the behavior of the existing wrapper because I'm sure folks depend on replicating the model on every call to forward. It shouldn't be too hard though, and can use nn.DataParallel as a starting point.
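
As a rough illustration of what such a companion class could look like, here is a minimal sketch. It is not the replicate-once-and-broadcast scheme described above: parameters and buffers are still replicated on every call, so gradient reduction stays identical to stock DataParallel, and only a user-listed set of plain attributes is carried over from one call's replicas to the next. The class name and the state_attrs argument are hypothetical, and Lightning would additionally need this wired into its own DataParallel wrapper.

import torch.nn as nn

class StateCarryingDataParallel(nn.DataParallel):
    """Sketch of a companion class to nn.DataParallel.

    Parameters and buffers are still replicated on every forward call, so
    gradient reduction behaves exactly like stock DataParallel. The only
    addition is that a user-listed set of plain attributes is saved from
    each call's replicas and restored onto the next call's replicas.
    """

    def __init__(self, module, state_attrs=(), **kwargs):
        super().__init__(module, **kwargs)
        self.state_attrs = list(state_attrs)  # e.g. ['foo_out']
        self._saved_state = None              # one dict per replica

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])

        # fresh replicas every call, exactly like stock DataParallel
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])

        # restore the attributes saved from the previous call's replicas
        if self._saved_state is not None:
            for replica, saved in zip(replicas, self._saved_state):
                for name, value in saved.items():
                    setattr(replica, name, value)

        outputs = self.parallel_apply(replicas, inputs, kwargs)

        # save the listed attributes so the next call can see them again
        self._saved_state = [
            {name: getattr(replica, name)
             for name in self.state_attrs if hasattr(replica, name)}
            for replica in replicas
        ]
        return self.gather(outputs, self.output_device)

Wrapping the model as StateCarryingDataParallel(model, state_attrs=['foo_out'], device_ids=[0, 1]) would keep self.foo_out from the example at the top alive across the two optimizer calls without touching how gradients are reduced.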

Just wanted to check if there was any update/advice on this type of issue? I've got a similar situation with a GAN producing images in the first optimizer iteration then using them to update the discriminator in the second. It works well on a single GPU, but when distributing I run into the same issue. I initially thought adding the property as a buffer would maintain it, but it seems to be flushed when using DP in the same way. Is the only solution to run the generator in the discriminator's optimizer iteration?
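
For what it's worth, the workaround hinted at in that last question, regenerating the images inside each optimizer branch so that nothing has to survive across DP forward calls, could look roughly like this for the snippet at the top of the issue (the discriminator loss stays elided as in the original, and the generated image is detached for the discriminator update):

def training_step(self, batch, batch_nb, optimizer_i):
    foo = batch['foo']
    bar = batch['bar']

    if optimizer_i == 0:  # train discriminator
        # regenerate inside this branch and detach, so no state has to
        # survive across forward calls and no gradient reaches the generator
        foo_out = self.netG(foo).detach()
        d_loss = ...  # discriminator loss on foo_out, as in the original

        return {'loss': d_loss}

    elif optimizer_i == 1:  # train generator
        # regenerate with gradients enabled for the generator update
        foo_out = self.netG(foo)
        g_loss = F.l1_loss(foo_out, bar)

        return {'loss': g_loss}

The cost is one extra generator forward pass per step, but no state has to be kept on the module between calls.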

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

