Pytorch-lightning: NumpyMetric not mapping back to GPU in multi-GPU training

Created on 3 Aug 2020  ·  3 Comments  ·  Source: PyTorchLightning/pytorch-lightning

🐛 Bug

I created a NumpyMetric class for an involved metric that requires numpy operations; however, the metric fails when training on multiple GPUs. After some debugging, this appears to be due to the resulting tensor not being mapped back to the appropriate GPU (or any GPU for that matter).

To Reproduce

Steps to reproduce the behavior:

  1. Define a NumpyMetric class
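from pytorch_lightning.metrics.metric import NumpyMetric

class MyNumpyMetric(NumpyMetric):
    def forward(self, y_hat, y):
        # complicated numpy stuff (no calls to .cpu() or .cuda() or .to() or anything like that)
        metric = ...  # placeholder for the result of the numpy computation
        return metric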
  2. Instantiate it in __init__ and call it in validation_step of my LightningModule, e.g.,
import pytorch_lightning as pl

class MyNetwork(pl.LightningModule):
    def __init__(self, args):
        super().__init__()  # must run before any submodule is assigned
        # other init stuff
        self.my_metric = MyNumpyMetric('my_metric')

    def validation_step(self, batch, batch_idx):
        # other validation stuff
        my_metric = self.my_metric(y_hat, y)  # where y_hat and y are tensors, no .cpu(), .cuda(), .to() called on either
        out_dict = dict(val_my_metric=my_metric)
        return out_dict
  3. Run:
from pytorch_lightning import Trainer

model = MyNetwork(args)
trainer = Trainer(
    benchmark=True,
    check_val_every_n_epoch=1,
    accumulate_grad_batches=1,
    min_epochs=n_epochs,
    max_epochs=n_epochs,
    fast_dev_run=False,
    gpus=2,
    distributed_backend='dp'
)
trainer.fit(model)
  4. See the error:
Traceback (most recent call last):
  File "./tiramisu3d.py", line 574, in <module>
    trainer.fit(model)
  File "/iacl/pg20/jacobr/miniconda3/envs/msseg-9.2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 997, in fit
    results = self.dp_train(model)
  File "/iacl/pg20/jacobr/miniconda3/envs/msseg-9.2/lib/python3.8/site-packages/pytorch_lightning/trainer/distrib_parts.py", line 270, in dp_train
    result = self.run_pretrain_routine(model)
  File "/iacl/pg20/jacobr/miniconda3/envs/msseg-9.2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1193, in run_pretrain_routine
    eval_results = self._evaluate(model,
  File "/iacl/pg20/jacobr/miniconda3/envs/msseg-9.2/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 293, in _evaluate
    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
  File "/iacl/pg20/jacobr/miniconda3/envs/msseg-9.2/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 444, in evaluation_forward
    output = model(*args)
  File "/iacl/pg20/jacobr/miniconda3/envs/msseg-9.2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/iacl/pg20/jacobr/miniconda3/envs/msseg-9.2/lib/python3.8/site-packages/pytorch_lightning/overrides/data_parallel.py", line 66, in forward
    return self.gather(outputs, self.output_device)
  File "/iacl/pg20/jacobr/miniconda3/envs/msseg-9.2/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 168, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/iacl/pg20/jacobr/miniconda3/envs/msseg-9.2/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 68, in gather
    res = gather_map(outputs)
  File "/iacl/pg20/jacobr/miniconda3/envs/msseg-9.2/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 61, in gather_map
    return type(out)(((k, gather_map([d[k] for d in outputs]))
  File "/iacl/pg20/jacobr/miniconda3/envs/msseg-9.2/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 61, in <genexpr>
    return type(out)(((k, gather_map([d[k] for d in outputs]))
  File "/iacl/pg20/jacobr/miniconda3/envs/msseg-9.2/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 55, in gather_map
    return Gather.apply(target_device, dim, *outputs)
  File "/iacl/pg20/jacobr/miniconda3/envs/msseg-9.2/lib/python3.8/site-packages/torch/nn/parallel/_functions.py", line 54, in forward
    assert all(map(lambda i: i.is_cuda, inputs))
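
The final frame shows that the gather step asserts every output is a CUDA tensor. The following is a minimal sketch of why a NumPy-derived result would fail that check. It is not the library's code: `all_on_cuda` is a hypothetical stand-in for the asserted condition, and the metric value is made up.

```python
import numpy as np
import torch


def all_on_cuda(tensors):
    """Hypothetical stand-in for the condition asserted in the traceback."""
    return all(t.is_cuda for t in tensors)


# A value computed with NumPy and wrapped in a tensor lives on the CPU
# unless it is moved explicitly.
metric = torch.from_numpy(np.array([0.5], dtype=np.float32))
print(metric.device)          # device of the converted result
print(all_on_cuda([metric]))  # the condition the gather step asserts on

# Moving the result to a GPU (when one is present) satisfies the check.
if torch.cuda.is_available():
    metric = metric.to(torch.device("cuda", 0))
    print(all_on_cuda([metric]))
```

On a machine without a GPU the last branch is skipped, so the sketch only demonstrates the failing side of the check there.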

Code sample

I will try to do this soon.

Expected behavior

I expected no error to occur. The documentation states: "[NumpyMetric] already handles DDP sync and input/output conversions." However, this doesn't appear to be the case in my implementation.

Environment

  • CUDA:
    • GPU:
      • Tesla M40 24GB
      • Tesla M40 24GB
    • available: True
    • version: 9.2
  • Packages:
    • numpy: 1.18.5
    • pyTorch_debug: False
    • pyTorch_version: 1.5.1
    • pytorch-lightning: 0.8.5
    • tensorboard: 2.3.0
    • tqdm: 4.48.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.8.3
    • version: #1 SMP Wed Sep 26 11:06:22 UTC 2018

Additional context

PyTorch and PyTorch Lightning were installed with conda (along with all of the other packages).

I was able to work around this error by adding the following .to() call to the validation step:

def validation_step(self, batch, batch_idx):
    # other validation stuff
    my_metric = self.my_metric(y_hat, y)
    my_metric = my_metric.to(y_hat.device)
    out_dict = dict(val_my_metric=my_metric)
    return out_dict

I presume, however, that this is not the intended way to use the NumpyMetric class.
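
The same idea can be kept next to the NumPy code instead of in validation_step. This is only a sketch under my own assumptions, not how NumpyMetric is implemented: `numpy_metric_on_input_device` and `mean_abs_error` are hypothetical names, and the metric is a stand-in for the real one.

```python
import numpy as np
import torch


def numpy_metric_on_input_device(metric_fn, y_hat, y):
    """Hypothetical helper: run a NumPy metric, return a tensor on y_hat's device."""
    # NumPy only works on host memory, so copy the inputs to the CPU first.
    value = metric_fn(y_hat.detach().cpu().numpy(), y.detach().cpu().numpy())
    # Wrap the result and place it on the same device as the input,
    # which is what the .to(y_hat.device) workaround does by hand.
    return torch.as_tensor(np.asarray(value), dtype=torch.float32, device=y_hat.device)


def mean_abs_error(pred, target):
    """Stand-in metric written purely with NumPy operations."""
    return np.abs(pred - target).mean()


y_hat = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([1.0, 1.0, 1.0])
result = numpy_metric_on_input_device(mean_abs_error, y_hat, y)
print(result.device == y_hat.device)
```

With inputs on a GPU, the returned tensor would sit on that same GPU, so the gather step under 'dp' should no longer see a CPU tensor.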

FWIW, I briefly looked at the code to see if I could just submit a PR with the fix (if this isn't user error), but it wasn't clear to me where the best places to look were. If you point me in the right direction, I might be able to submit a PR with the fix.

Labels: bug / fix, help wanted

All 3 comments

Hi! Thanks for your contribution, great first issue!

Hi, good news: I believe I have already fixed this in #2657, or at least that issue looks very similar. The fix has not been released yet, but it will be soon. If you need it now, install Lightning from the master branch. Your workaround is also fine.

@jcreinhold mind trying master? Feel free to reopen if needed 🐰
