When using a metric such as Accuracy from pytorch_lightning.metrics on a machine with 4 GPUs in 'dp' mode, there is an error caused by the metric state being accumulated on different devices. In the case of Accuracy, at this line:
https://github.com/PyTorchLightning/pytorch-lightning/blob/c8ccec7a02c53ed38af6ef7193232426384eee4a/pytorch_lightning/metrics/classification/accuracy.py#L108
the arguments to torch.sum are on the device the metric is being called from, but self.correct is on a different one. The traceback is as follows:
self.accuracy_val(y_hat, y)
File "/home/***/.conda/envs/***/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/***/.conda/envs/***/lib/python3.8/site-packages/pytorch_lightning/metrics/metric.py", line 153, in forward
self.update(*args, **kwargs)
File "/home/***/.conda/envs/***/lib/python3.8/site-packages/pytorch_lightning/metrics/metric.py", line 199, in wrapped_func
return update(*args, **kwargs)
File "/home/***/.conda/envs/***/lib/python3.8/site-packages/pytorch_lightning/metrics/classification/accuracy.py", line 109, in update
self.correct += torch.sum(preds == target)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!
https://colab.research.google.com/drive/1zcU1ADuHZj82clrBysv-EGfgqG7SxUhN#scrollTo=V7ELesz1kVQo
The shared Colab will not be able to reproduce the bug, since it needs 'dp' on multiple GPUs, but it should give an idea of when the error occurs. So setting
gpus=4,
accelerator="dp",
in the Trainer and then using a metric should reproduce the issue. I have tested it with Accuracy, but other users in the Slack channel reported it for other metrics such as Precision and Recall.
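For illustration, a minimal sketch of the kind of setup that triggers it; the module below is made up and incomplete (no training_step, optimizer or dataloaders), only the metric call and the Trainer flags matter:

import pytorch_lightning as pl
import torch
from torch import nn

class LitModel(pl.LightningModule):
    # hypothetical module, only here to show where the metric is called
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)
        self.accuracy_val = pl.metrics.Accuracy()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.layer(x)
        # under dp this runs on each replica's GPU, while the metric state
        # lives on another device -> RuntimeError in the metric update
        self.accuracy_val(y_hat, y)

trainer = pl.Trainer(gpus=4, accelerator="dp")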
The devices should be the same when the values are added together. I am not sure which would be the correct approach; I have crudely solved it with:
self.correct += torch.sum(preds.cuda(self.correct.device.index) == target.cuda(self.correct.device.index))
self.total += target.cuda(self.correct.device.index).numel()
in the case of Accuracy, but that is quite an ugly way of dealing with it.
Update: Although this doesn't produce the error, the accuracy is not properly computed, as values get reset to 0 for some reason between steps.
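For reference, the same device fix can be written a little more cleanly with .to(); it only hides the device error and still suffers from the reset problem mentioned in the update above. A sketch of the patched lines inside Accuracy.update:

# sketch: move the incoming tensors onto the device of the state buffers
preds = preds.to(self.correct.device)
target = target.to(self.correct.device)
self.correct += torch.sum(preds == target)
self.total += target.numel()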
If you do metrics update in <step name>_step_end() methods it will work correctly.
@marrrcin could you provide a bit more detail? Like how and where do you explicitly call metric? (example)
I would like to get to the bottom of this.
So, according to a conversation with him on Slack, if metric.forward() is called in the *_step_end() method (say, self.acc(logits, true)) rather than in *_step(), the error no longer happens.
This probably has to do with the fact that the output values from the step methods are properly gathered across devices into *_step_end. Still, I am not sure that this is the intended behavior of the metrics modules when using DP, since then they could only be called in training_step_end, validation_step_end and test_step_end.
Maybe @marrrcin can clarify when online, but that is what I understood.
@SkafteNicki
So my workaround is the following. In my LightningModule's __init__ I have:
METRICS_SUFFIX = "_metrics"
TRAINING = "train"
TESTING = "test"
VALIDATION = "val"

metrics_factory = lambda: nn.ModuleList(
    [
        pl.metrics.Accuracy(),
        pl.metrics.Precision(hparams.num_classes),
        pl.metrics.Recall(hparams.num_classes),
        pl.metrics.Fbeta(hparams.num_classes),
    ]
)
self.metrics: nn.ModuleDict = nn.ModuleDict(
    {
        # you cannot have the name `train` in ModuleDict, because nn.Module has a method called `train`
        TRAINING + METRICS_SUFFIX: metrics_factory(),
        VALIDATION + METRICS_SUFFIX: metrics_factory(),
        TESTING + METRICS_SUFFIX: metrics_factory(),
    }
)
Then, some utility functions:
@staticmethod
def _get_metric_name(m: pl.metrics.Metric):
    return m.__class__.__name__.lower()

def update_metrics(
    self,
    step_name,
    pred_y: torch.Tensor,
    true_y: torch.Tensor,
    log_metrics=True,
):
    for metric in self.metrics[step_name + METRICS_SUFFIX]:
        m = metric(pred_y, true_y)
        if log_metrics:
            self.log(
                f"{self._get_metric_name(metric)}/{step_name}",
                m,
                on_step=False,
                on_epoch=True,
            )
Then, from *_step I return dicts with `loss`, `pred_y` and `true_y` keys.
Lastly, in *_step_end I call the update:
def training_step_end(self, outputs: dict) -> torch.Tensor:
    loss = outputs["loss"].mean()
    self.update_metrics(TRAINING, outputs["pred_y"], outputs["true_y"])
    # etc...
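A rough sketch of what such a *_step could look like (the forward pass and loss here are placeholders, not the actual implementation):

def training_step(self, batch, batch_idx):
    x, true_y = batch
    pred_y = self(x)  # placeholder forward pass
    loss = torch.nn.functional.cross_entropy(pred_y, true_y)  # illustrative loss
    # return everything training_step_end needs once dp has gathered the outputs
    return {"loss": loss, "pred_y": pred_y, "true_y": true_y}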
Seems to be a duplicate of / closely related to #4073, you might want to try the fix in #4138.
@EspenHa I agree that the error is the same, but this happens long before trying to log the metrics. Actually, it has nothing to do with the rest of Lightning; it is a general incompatibility between Lightning metrics and DataParallel:
from pytorch_lightning import metrics
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
metric = metrics.Accuracy()
metric_dp = torch.nn.DataParallel(metric)
metric_dp.to(device)
pred = torch.randint(2, size=(10,))
target = torch.randint(2, size=(10,))
val = metric_dp(pred, target)
Even this small example fails with the same error. The core of the problem seems to have something to do with how we register the state as a buffer. As far as I understand, if we use self.register_buffer (which we do), PyTorch should move the states to the correct devices, but this does not seem to happen.
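A quick way to check where those states live (assuming they show up as ordinary buffers, which is what register_buffer implies) is to inspect named_buffers() before and after moving the metric:

import torch
from pytorch_lightning import metrics

metric = metrics.Accuracy()
# the internal states should be listed here if they are registered buffers
print({name: buf.device for name, buf in metric.named_buffers()})

if torch.cuda.is_available():
    metric.to("cuda:0")
    # after .to() the buffers should have followed the module to cuda:0
    print({name: buf.device for name, buf in metric.named_buffers()})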
Just a small update:
It seems that the pitfall is indeed that we use self.register_buffer for the internal states in metrics. They also cause trouble in ddp mode, since on each forward pass the buffer on rank 0 overwrites the buffer on all other ranks, leading to wrong results. I am trying at the moment to come up with a solution for this.
That said, another problem with metrics in dp mode has come to my attention. Since dp creates and destroys replicas of the model on each forward call, the internal state of the metrics is destroyed before we have a chance to accumulate it over the different devices. Therefore, until we implement some kind of state maintenance in dp (PR: #1895), the only way forward right now is (thanks @marrrcin for the workaround):
return preds, target in <mode>_step (<mode> being either training, val or test) and then call the metric in <mode>_step_end, as sketched below.
@teddykoker @ananyahjha93 please take a look as well
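To spell the pattern out, a minimal sketch (reusing self.accuracy_val from the original report; the forward pass is a placeholder):

def validation_step(self, batch, batch_idx):
    x, target = batch
    preds = self(x)  # placeholder forward pass
    # do not update the metric here under dp, just hand the tensors on
    return {"preds": preds, "target": target}

def validation_step_end(self, outputs):
    # in dp mode the outputs from all GPUs are gathered onto one device here,
    # so the metric update no longer mixes devices
    self.accuracy_val(outputs["preds"], outputs["target"])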
This is currently not supported until we have stateful DP; see https://github.com/PyTorchLightning/pytorch-lightning/pull/1895
@ananyahjha93 what should be next steps? should we change the docs?
@edenlightning we need to get the state maintenance for DP in before we tackle this. Even if it doesn't solve this completely, that PR has been lying around for a few months.