I'm running the following on 4 GPUs:
```python
import torch
import torch.nn as nn
from apex import amp

model = Resnet50()  # user-defined ResNet-50 (models/resnet.py)
model = model.cuda()
criterion = nn.CrossEntropyLoss(reduction='mean').cuda()
optimizer = torch.optim.SGD(model.parameters(), 0.001)
model, optimizer = amp.initialize(model, optimizer, opt_level='O3', keep_batchnorm_fp32=False)
model = torch.nn.DataParallel(model)
```
And I get the following error:
```
Selected optimization level O3: Pure FP16 training.
Defaults for this optimization level are:
enabled : True
opt_level : O3
cast_model_type : torch.float16
patch_torch_functions : False
keep_batchnorm_fp32 : False
master_weights : False
loss_scale : 1.0
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled : True
opt_level : O3
cast_model_type : torch.float16
patch_torch_functions : False
keep_batchnorm_fp32 : False
master_weights : False
loss_scale : 1.0
lr: 0.1 wd 0.0001
Traceback (most recent call last):
  File "main.py", line 559, in <module>
    main()
  File "main.py", line 555, in main
    train(train_loader, val_loader, model, criterion, optimizer, start_epoch, best_acc, args)
  File "main.py", line 409, in train
    output = model(input_var, epoch=epoch, i=i)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
    output.reraise()
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/_utils.py", line 369, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/apex/amp/_initialize.py", line 194, in new_fwd
    **applier(kwargs, input_caster))
  File "/home/michael/noisynet/models/resnet.py", line 161, in forward
    x = self.conv1(x)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 343, in forward
    return self.conv2d_forward(input, self.weight)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 340, in conv2d_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)
```
This is expected behavior: PyTorch DataParallel only works with O1. You may refer to #269 for more details.
I found the root cause: forward must be patched after the DataParallel(...) call, because otherwise the patched method refers to the old model object rather than to the dynamically created replica. There may be some other way of patching that works with DP, but it is definitely not the straightforward one in https://github.com/NVIDIA/apex/blob/master/apex/amp/_initialize.py#L201
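The stale-binding problem described above can be reproduced in plain Python, without GPUs. This is a toy sketch, assuming that replication shallow-copies a module's instance `__dict__` (which is how an instance-level `forward` attribute ends up on every replica). `ToyModule`, `patch_forward` and `replicate` are hypothetical stand-ins, not the real torch or apex code.

```python
import copy


class ToyModule:
    """Hypothetical stand-in for an nn.Module whose weights live on one device."""

    def __init__(self, device):
        self.device = device

    def forward(self, x_device):
        # Mimics cudnn's check that input and weight share a device.
        if x_device != self.device:
            raise RuntimeError(f"device {x_device} does not equal {self.device}")
        return f"ran on device {self.device}"


def patch_forward(module):
    # amp-style patch: an instance attribute that closes over the ORIGINAL
    # bound method, shadowing the class-level forward.
    old_fwd = module.forward
    module.forward = lambda x_device: old_fwd(x_device)
    return module


def replicate(module, device):
    # Toy replication: shallow-copy the instance dict, then "move" the copy.
    replica = copy.copy(module)
    replica.device = device
    return replica


# Unpatched: the replica's forward is the class method bound to the replica.
print(replicate(ToyModule(0), 1).forward(1))  # ran on device 1

# Patched BEFORE replication: the copied closure still calls device 0's method.
try:
    replicate(patch_forward(ToyModule(0)), 1).forward(1)
except RuntimeError as e:
    print(e)  # device 1 does not equal 0

# Patched AFTER replication: the closure binds to the replica itself.
print(patch_forward(replicate(ToyModule(0), 1)).forward(1))  # ran on device 1
```

The middle case reproduces the same "device 1 does not equal 0" mismatch as the traceback, which is why the order of patching and replication matters.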
The workaround I found:
```python
import torch
import apex

# amp patches forward() on the Sequential wrapper; indexing [0] then returns
# the fp16-cast inner model, whose own forward() is left unpatched.
model = apex.amp.initialize(torch.nn.Sequential(model), opt_level='O2')[0]
model = torch.nn.DataParallel(model, device_ids=args.devices)

def _amp_forward(*inputs, old_fwd=model.forward, **kwargs):
    # Re-apply amp's input/output casting on the DataParallel wrapper instead.
    options = apex.amp._amp_state.opt_properties.options
    input_caster = lambda tensor: tensor.to(options['cast_model_type'])
    out_dtype = (options['cast_model_outputs']
                 if options.get('cast_model_outputs') is not None else torch.float32)
    output_caster = lambda tensor: tensor.to(out_dtype)
    applier = apex.amp._initialize.applier
    return applier(old_fwd(*applier(inputs, input_caster), **applier(kwargs, input_caster)),
                   output_caster)

model.forward = _amp_forward
```
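The workaround relies on apex's `applier` helper to cast every tensor inside possibly nested inputs and outputs. As a rough illustration of that cast-wrap-cast pattern (not apex's actual code), here is a plain-Python sketch. `FakeTensor`, `apply_to_tensors` and `wrap_forward` are hypothetical names, and the string dtypes stand in for `torch.float16` / `torch.float32`.

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor: only tracks its dtype."""
    dtype: str

    def to(self, dtype):
        return FakeTensor(dtype)


def apply_to_tensors(value, fn):
    # Recursively apply fn to every tensor leaf, preserving container types.
    if isinstance(value, FakeTensor):
        return fn(value)
    if isinstance(value, dict):
        return {k: apply_to_tensors(v, fn) for k, v in value.items()}
    if type(value) in (list, tuple):
        return type(value)(apply_to_tensors(v, fn) for v in value)
    return value  # ints, strings, None, ... pass through unchanged


def wrap_forward(old_fwd, in_dtype, out_dtype):
    # Cast inputs to in_dtype, call the wrapped forward, cast outputs to out_dtype.
    def new_fwd(*args, **kwargs):
        args = apply_to_tensors(args, lambda t: t.to(in_dtype))
        kwargs = apply_to_tensors(kwargs, lambda t: t.to(in_dtype))
        return apply_to_tensors(old_fwd(*args, **kwargs), lambda t: t.to(out_dtype))
    return new_fwd


def toy_forward(x, epoch=0):
    # Reports the dtype it received, plus a tensor output in that dtype.
    return (x.dtype, FakeTensor(x.dtype))


fwd = wrap_forward(toy_forward, "float16", "float32")
out = fwd(FakeTensor("float32"), epoch=3)
# out == ("float16", FakeTensor(dtype="float32"))
```

Because the wrapper is attached to the outermost object, the casts run once on the full batch, and the replicas only ever execute their own class-level forward.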
@jianchao-li
This is still very useful information and I haven't been ignoring it, but to be honest I'm probably not going to implement a fix in Apex soon. My absolute top priority right now is getting automatic mixed precision into PyTorch natively, which will eliminate all extension building/version matching issues. I'm taking care to ensure the native integration will support DistributedDataParallel, DataParallel, and model parallel usage. We are targeting the 1.5 release:
https://github.com/pytorch/pytorch/issues/25081
Gradient scaling and autocasting will be independently usable components.
The gradient scaling PR is mature, awaiting final documentation review:
https://github.com/pytorch/pytorch/pull/26512
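For context on what gradient scaling does: the usual dynamic loss-scaling recipe multiplies the loss by a scale factor before backward, skips the optimizer step and shrinks the scale when any gradient overflows, and grows the scale again after a run of overflow-free steps. Below is a minimal plain-Python sketch of just that update rule. `LossScaler` is a hypothetical name, and the defaults (start at 2**16, halve on overflow, double after 2000 clean steps) match the documented defaults of `torch.cuda.amp.GradScaler` but are otherwise illustrative.

```python
import math


class LossScaler:
    """Toy dynamic loss scaler (hypothetical; not a library class)."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0

    def step(self, scaled_grads, apply_update):
        """Unscale grads and call apply_update(grads) only if all are finite.

        Returns True if the update was applied, False if it was skipped.
        """
        grads = [g / self.scale for g in scaled_grads]
        if not all(math.isfinite(g) for g in grads):
            self.scale *= self.backoff_factor  # overflow: shrink scale, skip step
            self._clean_steps = 0
            return False
        apply_update(grads)
        self._clean_steps += 1
        if self._clean_steps == self.growth_interval:
            self.scale *= self.growth_factor  # long clean run: try a larger scale
            self._clean_steps = 0
        return True


scaler = LossScaler(init_scale=8.0, growth_interval=2)
applied = []
scaler.step([float("inf")], applied.append)  # overflow: skipped, scale -> 4.0
scaler.step([4.0, 8.0], applied.append)      # applies [1.0, 2.0]
scaler.step([4.0], applied.append)           # 2nd clean step: scale -> 8.0
```

Skipping the step on overflow is what makes the aggressive starting scale safe: a bad iteration costs one update rather than corrupting the weights with inf/NaN.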
The autocasting PR is about 3/4 done in terms of op coverage:
https://github.com/pytorch/pytorch/pull/29552
Autocasting will likely be exposed via a context manager that can be used to locally enable/disable mixed precision for any desired regions of the model.
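As a toy model of locally enabling/disabling mixed precision for chosen regions, here is a plain-Python sketch of a nestable context manager that sets a flag and restores the previous value on exit. `autocast_region` and `current_compute_dtype` are hypothetical names for illustration only, not the API from the PR; native autocasting eventually shipped as `torch.cuda.amp.autocast` in PyTorch 1.6.

```python
from contextlib import contextmanager

# Module-level flag for simplicity; a real implementation would likely keep
# this state per thread.
_autocast_enabled = False


@contextmanager
def autocast_region(enabled=True):
    """Hypothetical: enable or disable mixed precision inside a `with` block."""
    global _autocast_enabled
    previous = _autocast_enabled
    _autocast_enabled = enabled
    try:
        yield
    finally:
        _autocast_enabled = previous  # restore the outer setting, even on error


def current_compute_dtype():
    # An op would consult the flag to pick its compute precision.
    return "float16" if _autocast_enabled else "float32"


with autocast_region():
    inside = current_compute_dtype()           # "float16"
    with autocast_region(enabled=False):
        sensitive = current_compute_dtype()    # "float32" for a precision-sensitive region
    after_inner = current_compute_dtype()      # "float16" again
outside = current_compute_dtype()              # "float32"
```

Saving and restoring the previous value, rather than resetting to a fixed default, is what lets disabled regions nest inside enabled ones.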
If you are having problems with the current incarnation of Apex, my best advice is to wait for the PRs to be merged. Getting native mixed precision support as soon as possible is the best path forward for everyone IMO.
@mcarilli By the way, are O2/O3 supported in PyTorch autocast? Colleagues mentioned that they saw no RAM decrease when using PyTorch core autocast, as if activations were still stored in fp32.
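For a sense of the scale behind that question: a saved activation costs its element count times the bytes per element, so keeping it in fp16 instead of fp32 halves its footprint. A back-of-the-envelope sketch; the shape (batch 256, the 64-channel 112×112 output of ResNet-50's first conv) is an illustrative assumption, and actual savings depend on which activations the framework really stores in fp16.

```python
from math import prod

BYTES_PER_ELEMENT = {"float32": 4, "float16": 2}


def activation_bytes(shape, dtype):
    # Memory for one saved activation tensor of the given shape and dtype.
    return prod(shape) * BYTES_PER_ELEMENT[dtype]


shape = (256, 64, 112, 112)  # assumed: batch 256, ResNet-50 first-conv output
fp32 = activation_bytes(shape, "float32")
fp16 = activation_bytes(shape, "float16")
print(f"fp32: {fp32 / 2**20:.0f} MiB, fp16: {fp16 / 2**20:.0f} MiB")
# fp32: 784 MiB, fp16: 392 MiB
```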