Pytorch-lightning: Metrics fail on DP and multiple GPUs

Created on 25 Oct 2020 · 11 comments · Source: PyTorchLightning/pytorch-lightning

🐛 Bug


When using a metric such as Accuracy from pytorch_lightning.metrics on a machine with 4 GPUs in 'dp' mode, an error occurs because the metric is accumulated across different devices. In the case of Accuracy, the failing line is:
https://github.com/PyTorchLightning/pytorch-lightning/blob/c8ccec7a02c53ed38af6ef7193232426384eee4a/pytorch_lightning/metrics/classification/accuracy.py#L108

The arguments to torch.sum are on the device the metric is being called from, but self.correct is on a different one. The traceback is as follows:

    self.accuracy_val(y_hat, y)
  File "/home/***/.conda/envs/***/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/***/.conda/envs/***/lib/python3.8/site-packages/pytorch_lightning/metrics/metric.py", line 153, in forward
    self.update(*args, **kwargs)
  File "/home/***/.conda/envs/***/lib/python3.8/site-packages/pytorch_lightning/metrics/metric.py", line 199, in wrapped_func
    return update(*args, **kwargs)
  File "/home/***/.conda/envs/***/lib/python3.8/site-packages/pytorch_lightning/metrics/classification/accuracy.py", line 109, in update
    self.correct += torch.sum(preds == target)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

Please reproduce using the BoringModel and post here


https://colab.research.google.com/drive/1zcU1ADuHZj82clrBysv-EGfgqG7SxUhN#scrollTo=V7ELesz1kVQo

To Reproduce


The shared Colab will not be able to replicate the bug, since it requires 'dp' on multiple GPUs, but it should give an idea of when the error occurs. Setting

        gpus=4,
        accelerator="dp",

in the Trainer and then using a metric should trigger the issue. I have tested it with Accuracy, but other users in the Slack channel have reported it for other metrics such as Precision and Recall.
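
For illustration, a hypothetical minimal module that should trigger the error (it mirrors the BoringModel idea from the Colab; the class name, layer sizes, and loss are made up):

    import torch
    import pytorch_lightning as pl

    class LitClassifier(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)
            self.accuracy_val = pl.metrics.Accuracy()

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.cross_entropy(self(x), y)

        def validation_step(self, batch, batch_idx):
            x, y = batch
            # updating the metric here raises the RuntimeError above under dp on multiple GPUs
            self.accuracy_val(self(x).argmax(dim=-1), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

    trainer = pl.Trainer(gpus=4, accelerator="dp", max_epochs=1)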

Expected behavior

The devices should be the same when the values are added together. I am not sure which would be the correct approach; I have crudely worked around it with:

        self.correct += torch.sum(preds.cuda(self.correct.device.index) == target.cuda(self.correct.device.index))
        self.total += target.cuda(self.correct.device.index).numel()

in the case of Accuracy, but that is quite an ugly way of dealing with it.
Update: although this no longer raises the error, the accuracy is not computed properly, as the values get reset to 0 between steps for some reason.
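
A slightly tidier sketch of the same idea, assuming the incoming tensors should simply be moved to the device of the state tensors (it omits the metric's usual input formatting and, per the update above, still does not address the state being reset under dp):

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # move the incoming tensors to whatever device the accumulated state
        # lives on, instead of hard-coding .cuda(index)
        preds = preds.to(self.correct.device)
        target = target.to(self.correct.device)
        self.correct += torch.sum(preds == target)
        self.total += target.numel()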

Environment

  • CUDA:
    - GPU:
    - GeForce GTX 1080 Ti
    - GeForce GTX 1080 Ti
    - GeForce GTX 1080 Ti
    - GeForce GTX 1080 Ti
    - available: True
    - version: 10.2
  • Packages:
    - numpy: 1.19.2
    - pyTorch_debug: False
    - pyTorch_version: 1.6.0
    - pytorch-lightning: 1.0.3
    - tqdm: 4.50.2
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor:
    - python: 3.8.5
    - version: #1 SMP Debian 4.19.152-1 (2020-10-18)
Labels: DP, Metrics, enhancement, help wanted

Most helpful comment

Just a small update:
It seems that the pitfall is in fact that we use self.register_buffer for the internal states in metrics. They also cause trouble in ddp mode, since on each forward pass the buffer on rank 0 overwrites the buffer on all other ranks, leading to wrong results. I am currently trying to come up with a solution for this.

That said, another problem with metrics in dp mode has come to my attention. Since dp creates and destroys replicas of the model on each forward call, the internal state of the metrics is destroyed before we have a chance to accumulate it across the different devices. Therefore, until we implement some kind of state maintenance in dp (PR: #1895), the only way forward right now is (thanks @marrrcin for the workaround):

  • return preds, target in <mode>_step (<mode> being training, val, or test)
  • call the metric in <mode>_step_end

All 11 comments

If you do the metric updates in the <step_name>_step_end() methods, it will work correctly.

@marrrcin could you provide a bit more detail? Like how and where do you explicitly call the metric? (an example would help)
I would like to get to the bottom of this.

So, according to a conversation with him on Slack, if metric.forward() is called in the *_step_end() method (let's say self.acc(true, logits)) rather than in *_step(), the error no longer happens.

This probably has to do with the fact that the output values from the step methods are properly gathered across devices into *_step_end. Still, I am not sure this is the intended behavior of the metric modules when using DP, since they could then only be called in training_step_end, validation_step_end, and test_step_end.

Maybe @marrrcin can clarify when he is online, but that is what I understood.

@SkafteNicki
so my workaround is the following:
In my LightningModule's __init__ I have:

        # these constants are defined once (e.g. at module level) so that the
        # methods below can also see them; shown here for context
        METRICS_SUFFIX = "_metrics"
        TRAINING = "train"
        TESTING = "test"
        VALIDATION = "val"
        metrics_factory = lambda: nn.ModuleList(
            [
                pl.metrics.Accuracy(),
                pl.metrics.Precision(hparams.num_classes),
                pl.metrics.Recall(hparams.num_classes),
                pl.metrics.Fbeta(hparams.num_classes),
            ]
        )
        self.metrics: nn.ModuleDict = nn.ModuleDict(
            {
                # you cannot have name `train` in ModuleDict, because nn.Module has function called `train`
                TRAINING + METRICS_SUFFIX: metrics_factory(),
                VALIDATION + METRICS_SUFFIX: metrics_factory(),
                TESTING + METRICS_SUFFIX: metrics_factory(),
            }
        )

Then, some utility functions:

    @staticmethod
    def _get_metric_name(m: pl.metrics.Metric):
        return m.__class__.__name__.lower()

    def update_metrics(
        self,
        step_name,
        pred_y: torch.Tensor,
        true_y: torch.Tensor,
        log_metrics=True,
    ):
        for metric in self.metrics[step_name + METRICS_SUFFIX]:
            m = metric(pred_y, true_y)
            if log_metrics:
                self.log(
                    f"{self._get_metric_name(metric)}/{step_name}",
                    m,
                    on_step=False,
                    on_epoch=True,
                )

Then, from *_step I return dicts with the keys loss, pred_y, and true_y.
Lastly, in *_step_end I call the update:

    def training_step_end(self, outputs: dict) -> torch.Tensor:
        loss = outputs["loss"].mean()

        self.update_metrics(TRAINING, outputs["pred_y"], outputs["true_y"])
    # etc. (the same pattern applies in validation_step_end and test_step_end)
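
For completeness, a hypothetical training_step producing the dict consumed above could look like this (the forward call and loss are illustrative; only the returned keys matter):

    def training_step(self, batch, batch_idx):
        x, true_y = batch
        logits = self(x)
        loss = torch.nn.functional.cross_entropy(logits, true_y)
        # return everything *_step_end needs; dp gathers these across devices
        return {"loss": loss, "pred_y": logits.argmax(dim=-1), "true_y": true_y}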

Seems to be a duplicate of / closely related to #4073; you might want to try the fix in #4138.

@EspenHa I agree that the error is the same, but this happens long before the metrics are logged. Actually, it has nothing to do with the rest of Lightning; it is a general incompatibility between Lightning metrics and DataParallel:

    from pytorch_lightning import metrics
    import torch
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    metric = metrics.Accuracy()
    metric_dp = torch.nn.DataParallel(metric)
    metric_dp.to(device)
    pred = torch.randint(2, size=(10,))
    target = torch.randint(2, size=(10,))
    val = metric_dp(pred, target)

Even this small example fails with the same error. The core of the problem seems to be how we register the state as a buffer. As far as I understand, if we use self.register_buffer (which we do), PyTorch should move the states to the correct devices, but this does not seem to happen.

Just a small update:
It seems that the pitfall is in fact that we use self.register_buffer for the internal states in metrics. They also cause trouble in ddp mode, since on each forward pass the buffer on rank 0 overwrites the buffer on all other ranks, leading to wrong results. I am currently trying to come up with a solution for this.

That said, another problem with metrics in dp mode has come to my attention. Since dp creates and destroys replicas of the model on each forward call, the internal state of the metrics is destroyed before we have a chance to accumulate it across the different devices. Therefore, until we implement some kind of state maintenance in dp (PR: #1895), the only way forward right now is (thanks @marrrcin for the workaround; a short sketch follows the list):

  • return preds, target in <mode>_step (<mode> being training, val, or test)
  • call the metric in <mode>_step_end
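
A minimal sketch of that pattern (the method bodies and the metric attribute name val_acc are illustrative):

    def validation_step(self, batch, batch_idx):
        x, target = batch
        preds = self(x).argmax(dim=-1)
        # only return the tensors here; do not touch the metric yet
        return {"preds": preds, "target": target}

    def validation_step_end(self, outputs):
        # dp has gathered the per-device outputs onto one device at this point,
        # so the metric state and its inputs now live on the same device
        acc = self.val_acc(outputs["preds"], outputs["target"])
        self.log("val_acc", acc, on_step=False, on_epoch=True)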

@teddykoker @ananyahjha93 please take a look as well

This is currently not supported until we have stateful DP; see https://github.com/PyTorchLightning/pytorch-lightning/pull/1895

@ananyahjha93 what should the next steps be? Should we change the docs?

@edenlightning we need to get the state maintenance for DP in before we tackle this. Even if it doesn't solve this completely, that PR has been lying around for a few months.
