Pytorch-lightning: Cannot pickle custom metric with DDP mode.

Created on 6 Aug 2020  ·  4 Comments  ·  Source: PyTorchLightning/pytorch-lightning

🐛 Bug

To Reproduce

Steps to reproduce the behavior:

  1. Just run this script.

Code sample

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.metrics import Metric, TensorMetric

class MetricPerplexity(Metric):
    """
    Computes the perplexity of the model.
    """

    def __init__(self, pad_idx: int, *args, **kwargs):
        super().__init__(name='ppl')
        self.pad_idx = pad_idx

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Per-token negative log-likelihood, kept unreduced so padding can be masked out
        loss = F.cross_entropy(pred, target, reduction='none')
        non_padding = target.ne(self.pad_idx)
        loss = loss.masked_select(non_padding).sum()

        # Perplexity = exp(mean NLL over non-padding tokens), capped to avoid overflow
        num_words = non_padding.sum()
        ppl = torch.exp(
            torch.min(loss / num_words, torch.tensor([100]).type_as(loss))
        )
        return ppl

class TensorPerplexity(TensorMetric):
    """
    Computes the perplexity of the model.
    """

    def __init__(self, pad_idx: int, *args, **kwargs):
        super().__init__(name='ppl', *args, **kwargs)
        self.pad_idx = pad_idx

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Same computation as MetricPerplexity.forward; only the base class differs
        loss = F.cross_entropy(pred, target, reduction='none')
        non_padding = target.ne(self.pad_idx)
        loss = loss.masked_select(non_padding).sum()

        num_words = non_padding.sum()
        ppl = torch.exp(
            torch.min(loss / num_words, torch.tensor([100]).type_as(loss))
        )
        return ppl

class ModelWithMetric(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(50, 1)
        self.ppl = MetricPerplexity(100)

    def training_step(self, batch, batch_nb):
        return {'loss': torch.mean(self(batch))}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), 0.01)

    def forward(self, batch):
        return self.lin(batch[0])

class ModelWithTensorMetric(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(50, 1)
        self.ppl = TensorPerplexity(100)

    def training_step(self, batch, batch_nb):
        return {'loss': torch.mean(self(batch))}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), 0.01)

    def forward(self, batch):
        return self.lin(batch[0])

if __name__ == "__main__":
    m1 = ModelWithMetric()
    m2 = ModelWithTensorMetric()

    t1 = pl.Trainer(distributed_backend='ddp_cpu', fast_dev_run=True)
    t2 = pl.Trainer(distributed_backend='ddp_cpu', fast_dev_run=True)

    loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.rand(10, 50)), batch_size=5)
    t1.fit(m1, train_dataloader=loader)
    print('works well')
    t2.fit(m2, train_dataloader=loader)

Error log

AttributeError: Can't pickle local object '_apply_to_outputs.<locals>.decorator_fn.<locals>.new_func'
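
The traceback appears while torch.multiprocessing spawns the DDP workers: spawning pickles the whole LightningModule, so every attribute (including metric objects) must be picklable. A minimal sketch of that step, using the classes from the script above (the second call should fail with the same AttributeError):

import pickle

# Spawning DDP workers pickles the module; this mimics that step directly
pickle.dumps(ModelWithMetric())        # the plain Metric subclass pickles fine
pickle.dumps(ModelWithTensorMetric())  # expected: Can't pickle local object ...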

Expected behavior

I am not sure how to explain this behaviour; I think both of these classes should work with DDP.

Environment

  • OS: macOS

  • CUDA:
    - GPU:
    - available: False
    - version: None

  • Packages:
    - numpy: 1.18.5
    - pyTorch_debug: False
    - pyTorch_version: 1.6.0
    - pytorch-lightning: 0.9.0rc9
    - tensorboard: 2.2.2
    - tqdm: 4.48.0

  • System:
    - OS: Darwin
    - architecture:
    - 64bit
    - processor: i386
    - python: 3.7.7
    - version: Darwin Kernel Version 19.6.0: Sun Jul 5 00:43:10 PDT 2020

Additional context

Labels: Metrics, bug / fix, help wanted

All 4 comments

Hi @c00k1ez ,

we are aware of that; it is due to our decorators. I have some local experiments on this, but I have not yet been able to get it running completely. Since some of the issues may be resolved by #2528, I'd like to wait for that one before trying to tackle this.
EDIT: basically this is due to https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/metrics/converters.py#L25 and https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/metrics/converters.py#L50.
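
In isolation, the failure looks like this: pickle serializes functions by qualified name, and the closure a decorator creates carries a '<locals>' qualname that cannot be resolved again. A minimal sketch, independent of Lightning:

import pickle

def decorator_fn(func):
    def new_func(*args, **kwargs):
        return func(*args, **kwargs)
    # new_func.__qualname__ is 'decorator_fn.<locals>.new_func'
    return new_func

@decorator_fn
def forward(x):
    return x

# AttributeError: Can't pickle local object 'decorator_fn.<locals>.new_func'
pickle.dumps(forward)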

If you want to take this over, the way to go is probably to use https://docs.python.org/3.8/library/functools.html#functools.wraps
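
For illustration, a minimal sketch of how functools.wraps changes the picture (not the actual converters.py code): wraps copies __module__ and __qualname__ from the wrapped function, so pickle can resolve the decorated function under its original module-level name again:

import functools
import pickle

def decorator_fn(func):
    @functools.wraps(func)  # copies __module__ / __qualname__ onto new_func
    def new_func(*args, **kwargs):
        return func(*args, **kwargs)
    return new_func

@decorator_fn
def forward(x):
    return x

# pickle looks up 'forward' by name and finds new_func itself, so this now works
pickle.dumps(forward)

Note this only works because the decorated function stays bound to its original name at module scope; decorating bound methods at runtime would need a different approach.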

Ok 👌
Yeah, thanks for the links, I will look into it :+1:

To check up on this again: it seems like #2528 actually resolves this issue, so it's probably best not to investigate much further until we have verified this with #2528.

Just confirmed that #2528 actually solved this issue (using the provided code). Closing the issue.
