Apex: torch.inverse needs to be blacklisted

Created on 29 Mar 2020 · 4 Comments · Source: NVIDIA/apex

SUMMARY

torch.inverse does not support fp16 (Half) arguments on CUDA, so Apex AMP needs to blacklist it, i.e. always run it in fp32.

DETAILS

Running the following:

from apex import amp
import torch

class Foo(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.A = torch.nn.Parameter(torch.rand(dim, dim))

    def forward(self, X):
        AX = torch.matmul(self.A, X)
        return torch.inverse(AX)

dim = 3
model = Foo(dim).cuda()
optimizer = torch.optim.AdamW(model.parameters())
model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=False)

X = torch.rand(dim, dim).cuda()
Y = model(X)

results in the error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-4-46f64e772005> in <module>
      1 X = torch.rand(dim, dim).cuda()
----> 2 Y = model(X)
      3 Y

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

<ipython-input-2-d9366f5dfdeb> in forward(self, X)
      6     def forward(self, X):
      7         AX = torch.matmul(self.A, X)
----> 8         return torch.inverse(AX)

RuntimeError: "inverse_cuda" not implemented for 'Half'
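
Until torch.inverse is blacklisted in Apex itself, one possible user-side stopgap is to register it as an fp32 function through Apex's `amp.register_float_function`. The sketch below is untested against the setup in this issue; the helper name `register_inverse_as_fp32` is hypothetical, and the registration has to happen before `amp.initialize` is called.

```python
import torch


def register_inverse_as_fp32():
    """Ask Apex AMP to always cast torch.inverse inputs to float32.

    Hypothetical helper, not part of Apex. Apex is imported lazily so the
    surrounding module still loads on machines where it is not installed.
    """
    from apex import amp  # requires NVIDIA Apex

    # Must run before amp.initialize(), which applies the registered casts.
    amp.register_float_function(torch, "inverse")
    return amp


# Intended usage with the reproduction script from this issue:
#   amp = register_inverse_as_fp32()
#   model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
```

With this in place, the matmul would still run in fp16 under O1 while the inversion receives an fp32 tensor.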

All 4 comments

This issue blocks mixed precision training in combination with affine transformations (here) on the GPU when using Kornia. Or is there a workaround I am not aware of?

The only workaround to use torch.inverse is to not use fp16 training mode since using mixed precision would make the torch.inverse operation unstable.
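
A narrower variant of this workaround keeps only the inversion out of fp16 instead of disabling mixed precision for the whole model. The following is a minimal sketch, assuming it is acceptable to upcast the matrix to float32 for the inversion and cast the result back; `inverse_fp32` is a hypothetical helper name, not an Apex or PyTorch function.

```python
import torch


def inverse_fp32(x: torch.Tensor) -> torch.Tensor:
    """Invert x in float32 when it arrives as fp16, then restore the dtype."""
    if x.dtype == torch.float16:
        # torch.inverse has no Half kernel, so upcast before inverting.
        return torch.inverse(x.float()).to(x.dtype)
    return torch.inverse(x)
```

In the `forward` from the reproduction, this would replace the call `torch.inverse(AX)` with `inverse_fp32(AX)`. Casting the result back to fp16 can lose precision for ill-conditioned matrices, so returning the fp32 result may be the safer choice in some models.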

I think with torch.cuda.amp, this will automatically be taken care of and is the recommended way to proceed.
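
Whether native autocast handles torch.inverse on its own may depend on the PyTorch version. A defensive sketch, following the pattern in PyTorch's AMP documentation for ops that need a specific dtype, is to switch autocast off locally and cast the input to float32. It assumes a PyTorch version that provides `torch.autocast`, and `inverse_without_autocast` is a hypothetical helper name.

```python
import torch


def inverse_without_autocast(x: torch.Tensor) -> torch.Tensor:
    """Run torch.inverse in float32 inside a region where autocast is disabled."""
    with torch.autocast(device_type="cuda", enabled=False):
        # The explicit cast is needed because x may already be fp16
        # if it was produced by an autocast-enabled op such as matmul.
        return torch.inverse(x.float())
```

The result is returned in float32; the caller decides whether to cast it back.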

@DeepVoltaire @prajjwal1 Still have this issue. Has anyone gotten around it?
