Apex: torch.autograd.grad()

Created on 21 Mar 2019 · 8 Comments · Source: NVIDIA/apex

I'm using a gradient penalty, something like the following:

y = model(x)
loss = some_loss_func(y)

gradients = torch.autograd.grad(
  outputs=y,
  inputs=x,
  grad_outputs=y.new_ones(y.size()),
  create_graph=True,
  retain_graph=True,
  only_inputs=True)[0]
gradients = gradients.view(gradients.size(0), -1)
penalty = (gradients.norm(2, dim=1) ** 2).mean()

loss += penalty
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
  scaled_loss.backward()

This results in the error RuntimeError: expected type torch.cuda.FloatTensor but got torch.cuda.HalfTensor on the .backward() call, in either O1 or O2 mode (but not O0 or O3). When I remove the gradient penalty, the code runs fine in all modes. I'm running on a single GPU.

Is this expected? If so, is there a suggested alternative?
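For reference, here is a minimal, self-contained sketch of the same penalty in plain FP32 on CPU, without Apex. The model, tensor shapes, and the stand-in loss are assumptions made for illustration and are not taken from the original training code. It only shows the double-backward pattern (create_graph=True followed by .backward()) that the snippet above relies on.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in model. BatchNorm is included because the thread
# later identifies it as the layer involved in the error.
model = nn.Sequential(
    nn.Linear(8, 16),
    nn.BatchNorm1d(16),
    nn.ReLU(),
    nn.Linear(16, 1),
)

# The inputs must require grad so that autograd.grad can differentiate w.r.t. them.
x = torch.randn(4, 8, requires_grad=True)
y = model(x)
loss = y.mean()  # stand-in for some_loss_func(y)

# First-order gradients of the outputs w.r.t. the inputs.
# create_graph=True keeps the graph so the penalty itself is differentiable.
gradients = torch.autograd.grad(
    outputs=y,
    inputs=x,
    grad_outputs=torch.ones_like(y),
    create_graph=True,
)[0]
gradients = gradients.view(gradients.size(0), -1)
penalty = (gradients.norm(2, dim=1) ** 2).mean()

# Backpropagating through the penalty differentiates the backward pass
# itself, including the batchnorm backward.
total = loss + penalty
total.backward()
```

In FP32 this runs without any dtype mismatch; the error reported here only shows up once Apex mixes FP16 and FP32 tensors in O1 or O2.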

All 8 comments

I discovered that this issue only occurs when batch normalization is used in the model. Otherwise it runs without any errors in both O1 and O2 modes.

Can you post the full backtrace of the errors you see?

Sure, here it is. I'm training a GAN, but I think this is separate from the issues of loss scaling noted elsewhere.

I can work on making a minimal example if that would be helpful.

RuntimeError                              Traceback (most recent call last)
~/projects/random_ai/gan/wgan/train.py in <module>()
     55 
     56 trainer = trainer.Trainer(args, dataloader, generator, discriminator)
---> 57 trainer.train()
     58 ##
     59 

~/projects/random_ai/gan/wgan/trainer.py in train(self)
    142 
    143             for _ in range(Diters):
--> 144                 self.train_discriminator()
    145 
    146             self.train_generator()

~/projects/random_ai/gan/wgan/trainer.py in train_discriminator(self)
     95         errD += GRAD_PENALTY_SCALE * penalty
     96         with amp.scale_loss(errD, self.optimizer_D) as scaled_loss:
---> 97             scaled_loss.backward()
     98         self.optimizer_D.step()
     99 

~/.pyenv/versions/3.6.4/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    100                 products. Defaults to ``False``.
    101         """
--> 102         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    103 
    104     def register_hook(self, hook):
~/.pyenv/versions/3.6.4/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     88     Variable._execution_engine.run_backward(
     89         tensors, grad_tensors, retain_graph, create_graph,
---> 90         allow_unreachable=True)  # allow_unreachable flag
     91 
     92 
RuntimeError: expected type torch.cuda.FloatTensor but got torch.cuda.HalfTensor

Here's a minimal example that throws the same error.

Hmm, I'm not in eval() mode, or using affine=False.

I think I know what's happening. This is a nice one...
With both O1 and O2, batchnorm weights are kept in FP32, which cuDNN batchnorm requires. In O1, batchnorm weights remain FP32 because all weights remain FP32. In O2, they remain FP32 because we explicitly special-case batchnorm, while the rest of the model weights are cast to FP16. cuDNN batchnorm forward can handle FP16 inputs + FP32 weights without trouble, and cuDNN batchnorm backward can handle FP16 incoming gradients + FP32 weights without trouble. However, when a backward pass with create_graph=True is underway, PyTorch falls back to a non-cuDNN (native) implementation of batchnorm backward that is double-differentiable. This native backward implementation cannot handle a combination of FP16 incoming gradients + FP32 weights, which (I suspect) causes your error.

There are a couple of approaches that might help here. With O1, you can try registering batchnorm as a blacklist function, which will ensure its inputs and outputs (and therefore its incoming gradients during backward) are cast to FP32:

amp.register_float_function(torch, 'batch_norm')
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

Alternatively, with O2, you can work around this by supplying the override keep_batchnorm_fp32=False, but this is less safe numerically, in my opinion.
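The two workarounds above could be wrapped roughly as follows. This is only a sketch: the helper name initialize_amp_for_double_backward is hypothetical, and the code assumes an Apex installation where amp.register_float_function and the keep_batchnorm_fp32 argument of amp.initialize behave as described in this thread.

```python
def initialize_amp_for_double_backward(model, optimizer, opt_level="O1"):
    """Hypothetical helper applying one of the two workarounds from this thread."""
    if opt_level not in ("O1", "O2"):
        raise ValueError("this sketch only covers opt_level 'O1' or 'O2'")

    # Imported lazily so the sketch can be loaded without Apex installed.
    import torch
    from apex import amp

    if opt_level == "O1":
        # Force batch_norm to run in FP32, so its incoming gradients
        # during backward are FP32 as well. Registration has to happen
        # before amp.initialize is called.
        amp.register_float_function(torch, "batch_norm")
        return amp.initialize(model, optimizer, opt_level="O1")

    # O2: also cast batchnorm weights to FP16 so weights and incoming
    # gradients share a dtype. Noted above as less safe numerically.
    return amp.initialize(
        model, optimizer, opt_level="O2", keep_batchnorm_fp32=False
    )
```

Usage would look like model, optimizer = initialize_amp_for_double_backward(model, optimizer, "O1"), followed by the usual amp.scale_loss block.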

Thanks for the careful analysis! Your suggestion of registering batch norm as an FP32-only function for O1 does seem to have fixed the issue.

@mcarilli Thanks for your analysis! It's OK for my work.
