Apply the following diff to today's master branch of maskrcnn-benchmark.
diff --git i/maskrcnn_benchmark/data/collate_batch.py w/maskrcnn_benchmark/data/collate_batch.py
index a7f0341..c906712 100644
--- i/maskrcnn_benchmark/data/collate_batch.py
+++ w/maskrcnn_benchmark/data/collate_batch.py
@@ -14,7 +14,7 @@ class BatchCollator(object):
     def __call__(self, batch):
         transposed_batch = list(zip(*batch))
-        images = to_image_list(transposed_batch[0], self.size_divisible)
+        images = transposed_batch[0]
         targets = transposed_batch[1]
         img_ids = transposed_batch[2]
         return images, targets, img_ids
diff --git i/maskrcnn_benchmark/engine/trainer.py w/maskrcnn_benchmark/engine/trainer.py
index 38a9e52..be0c9f1 100644
--- i/maskrcnn_benchmark/engine/trainer.py
+++ w/maskrcnn_benchmark/engine/trainer.py
@@ -60,12 +60,26 @@ def do_train(
         scheduler.step()
-        images = images.to(device)
+        from maskrcnn_benchmark.structures.image_list import to_image_list
+        from torch.nn.parallel.scatter_gather import scatter
+
+        orig_images = images
+
+        images = [scatter(im, [dist.get_rank()])[0] for im in images] # this fails
+        #images = [im.cuda() for im in images] # this works
+
+        for orig_im, im in zip(orig_images, images):
+            diff = (orig_im - im.cpu()).abs().max()
+            assert diff < 0.01, diff
+        images = to_image_list(images, 32)
+
         targets = [target.to(device) for target in targets]
         loss_dict = model(images, targets)
         losses = sum(loss for loss in loss_dict.values())
+        if torch.isnan(losses).any():
+            raise FloatingPointError()
         # reduce losses over all GPUs for logging purposes
         loss_dict_reduced = reduce_loss_dict(loss_dict)
Before the diff, the code (1) pads the images into an ImageList in the dataloader and (2) copies the ImageList to the GPU.
After the diff, the code (1) does not pad the images in the dataloader, (2) copies the images to the GPU individually, and (3) pads the GPU images into an ImageList.
The two should have equivalent behavior, and indeed it works when I use .cuda() to do the H2D copy.
However it fails when I use scatter()[0] to do the H2D copy (which is what DistributedDataParallel would use).
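For readers unfamiliar with the padding step, here is a rough sketch of how the padded batch shape could be computed. It assumes that to_image_list pads every image to the per-dimension maximum and rounds height and width up to a multiple of size_divisible; padded_batch_shape is a hypothetical helper written for illustration, not a function from maskrcnn-benchmark.

```python
import math


def padded_batch_shape(image_shapes, size_divisible=32):
    """Return the (N, C, H, W) shape of a zero-padded batch.

    image_shapes is a list of (C, H, W) tuples. This is an illustrative
    helper (an assumption, not the library's API).
    """
    if not image_shapes:
        raise ValueError("need at least one image shape")
    # Each dimension of the batch is the maximum over all images.
    channels = max(shape[0] for shape in image_shapes)
    height = max(shape[1] for shape in image_shapes)
    width = max(shape[2] for shape in image_shapes)
    if size_divisible > 0:
        # Round the spatial dimensions up to a multiple of size_divisible.
        height = int(math.ceil(height / size_divisible) * size_divisible)
        width = int(math.ceil(width / size_divisible) * size_divisible)
    return (len(image_shapes), channels, height, width)


print(padded_batch_shape([(3, 800, 1201), (3, 750, 1333)]))
```

Under that assumption, the batch shape depends only on the image shapes, so padding before or after the H2D copy should yield the same result.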
When I run this command on 2 GPUs after applying the diff:
python -m torch.distributed.launch --nproc_per_node=2 tools/train_net.py --config-file configs/e2e_mask_rcnn_R_50_FPN_1x.yaml SOLVER.IMS_PER_BATCH 4 SOLVER.BASE_LR 0.0001
I observed two types of failures after dozens of iterations:
(1) assert diff < 0.01 failed, which means the data was different after the copy:
  File "/maskrcnn-benchmark/maskrcnn_benchmark/engine/trainer.py", line 73, in do_train
    assert diff < 0.01, diff
AssertionError: tensor(139.0213)
(2) The FloatingPointError added in the diff was raised, which means the losses became NaN.
Neither failure occurs when I use .cuda() to do the copy.
It seems to be a bug in PyTorch's scatter, but I found no clues there. I'm also unable to simplify the repro (simplification makes the bug disappear), so I posted it here.
PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176
OS: Ubuntu 18.04.1 LTS
GCC version: (GCC) 5.3.0
CMake version: version 3.12.2
Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100
Nvidia driver version: 410.79
cuDNN version: Could not collect
Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.3.1
[pip3] numpy==1.16.1
[pip3] numpydoc==0.7.0
[pip3] torch==1.0.0
[pip3] torchvision==0.2.1
[conda] blas 1.0 mkl
[conda] mkl 2019.1 144
[conda] mkl-include 2019.1 144
[conda] mkl-service 1.1.2 py36he904b0f_5
[conda] mkl_fft 1.0.6 py36hd81dba3_0
[conda] mkl_random 1.0.2 py36hd81dba3_0
[conda] pytorch 1.0.0 py3.6_cuda9.0.176_cudnn7.4.1_1 pytorch
[conda] torchvision 0.2.1 py_2 pytorch
Hi,
There seems to be a problem with the streams in scatter, maybe in the stream synchronization. If you comment out this line in PyTorch, then the code runs fine on my end.
A few wild guesses: this is either
(1) cuda_guard.reset_stream, which is only used in this place in the codebase (also in the Gloo backend, but I've never really used gloo), or
(2) .to(..., non_blocking=True)
@mrshenli this is probably something related to https://github.com/pytorch/pytorch/pull/16966
@fmassa I agree with you, as you pointed out that the error disappears after commenting out _get_stream. Before pytorch/pytorch#16966, copies were always synchronized on the destination devices' default stream. Now they are synchronized on the destination devices' current stream (behavior on the src device is unchanged), which means:
I will look into this.
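To make the stream discussion above concrete, here is a minimal sketch of a host-to-device copy issued on a non-default stream, followed by an explicit wait on the current stream. It is not the code in scatter or in either PR. copy_on_side_stream and max_copy_error are hypothetical names, and the sketch assumes a recent PyTorch in which torch.cuda.current_stream accepts a device argument.

```python
try:
    import torch
except ImportError:  # lets the sketch import on machines without PyTorch
    torch = None


def copy_on_side_stream(cpu_tensor, device):
    """Copy cpu_tensor to device on a side stream, then order the
    current stream after that copy."""
    copy_stream = torch.cuda.Stream(device=device)
    with torch.cuda.stream(copy_stream):
        # The copy is queued on copy_stream, not on the current stream.
        gpu_tensor = cpu_tensor.to(device, non_blocking=True)
    current = torch.cuda.current_stream(device)
    # Later work on the current stream must wait for the copy to finish.
    current.wait_stream(copy_stream)
    # Tell the caching allocator the tensor is also used on this stream.
    gpu_tensor.record_stream(current)
    return gpu_tensor


def max_copy_error(num_tensors=8):
    """Return the largest abs difference after the copy, or None without CUDA."""
    if torch is None or not torch.cuda.is_available():
        return None
    device = torch.device("cuda", torch.cuda.current_device())
    worst = 0.0
    for _ in range(num_tensors):
        src = torch.randn(3, 128, 128).pin_memory()
        dst = copy_on_side_stream(src, device)
        worst = max(worst, (src - dst.cpu()).abs().max().item())
    return worst


print(max_copy_error())
```

Without the wait_stream call, nothing orders work on the current stream after the copy. Whether that is the exact cause of the failure reported here is not established at this point in the thread.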
@ppwwyyxx
~Is orig_images on cpu or gpu in this failed case?~ should be on cpu.
@mrshenli any updates on this issue?
@ppwwyyxx I was working on another urgent task. Back to this one now. :)
Hmm, I tried both the following code and tools/train_net.py with your patch under DDP on devgpu, but could not reproduce the failure (will try devfair next). Can you consistently reproduce this error? @ppwwyyxx
import argparse
from torch.nn.parallel.scatter_gather import scatter
import torch
import torch.distributed as dist
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend="nccl", init_method="env://")
print ("=== starting ", args.local_rank, ", ", dist.get_rank())
images = [torch.randn(3, 128, 128) for _ in range(10000)]
orig_images = images
images = [scatter(im, [dist.get_rank()])[0] for im in images]
for orig_im, im in zip(orig_images, images):
    diff = (orig_im - im.cpu()).abs().max()
    if diff != 0:
        print(diff)
@fmassa are you able to consistently reproduce this error with PyTorch master branch?
The instructions to reproduce are posted in the original issue. I cannot reproduce it with a simplified example, but it can be reliably reproduced with the original instructions.
Thanks, let me try
great, hit the bug using detectron2
Fixed by pytorch/pytorch/pull/18465