Pytorch-lightning: Native Amp Support

Created on 2 Apr 2020 · 7 comments · Source: PyTorchLightning/pytorch-lightning

Native automatic mixed precision support (torch.cuda.amp) is finally merged:
https://pytorch.org/docs/master/amp.html
https://pytorch.org/docs/master/notes/amp_examples.html
Apex Amp has many known pain points (extension builds, forward/backward compatibility, DataParallel support, flaky checkpointing, I don’t even know if it can be hacked to handle double backward/gradient penalty, others…). torch.cuda.amp fixes all of these, the interface is more flexible and intuitive, and the tighter integration brings more future performance optimizations into scope.

If you want to talk about adding torch.cuda.amp to Lightning, with an eye towards it becoming the true source of mixed precision and replacing Apex, message me on Pytorch slack anytime. I pinged you there as well, but I’m not sure if you monitor it habitually.

Labels: enhancement, help wanted

All 7 comments

I think the torch.cuda.amp API is a much better fit for Lightning because its style is more functional (functional as in, it doesn't statefully alter anything outside itself). The necessary torch.cuda.amp calls could be contained entirely within trainer.fit() without any silent/weird effects elsewhere.
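For concreteness, here is a minimal sketch of that pattern (a toy model and synthetic data, purely illustrative and not Lightning code): all of the amp state lives in the local GradScaler instance and the autocast context, and nothing outside the loop is altered.

```python
import torch

# Illustrative-only sketch of the torch.cuda.amp pattern (toy model/data, not
# Lightning internals). All amp state is local: the GradScaler instance and
# the autocast context; the model and optimizer are left untouched.
model = torch.nn.Linear(32, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    inputs = torch.randn(16, 32, device="cuda")
    targets = torch.randint(0, 10, (16,), device="cuda")

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        outputs = model(inputs)  # autocast-eligible ops run in mixed precision
        loss = torch.nn.functional.cross_entropy(outputs, targets)

    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales grads; skips the step if infs/NaNs are found
    scaler.update()                # adjust the scale factor for the next iteration
```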

This is awesome. Will definitely add! ETA on the next PyTorch release?
We can add forward compatibility.

@mcarilli does it still have the issues with saving/loading weights with the loss scaling factor?

@PyTorchLightning/core-contributors anyone interested in making this change?

one key consideration is saving/loading weights when amp scales the loss.

Yes, bitwise accurate saving/restoring is supported. You just need to call your GradScaler instance's state_dict() and load_state_dict() alongside the usual model/optimizer state_dict/load_state_dict calls.
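A minimal sketch of what that looks like (the checkpoint keys and file name are just for illustration):

```python
import torch

# Illustrative checkpointing sketch: the GradScaler state is saved and
# restored alongside the usual model/optimizer state dicts.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

torch.save(
    {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),
    },
    "checkpoint.pt",
)

# ...later, when resuming:
checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
scaler.load_state_dict(checkpoint["scaler"])
```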

@mcarilli any chance you'd be interested in submitting the PR?
I might be able to get to it by early this week, but it'd be great to have in 0.7.2 which is coming early next week.

was going to add checks like:
```python
if torch.__version__ >= "1.6":
    # new amp stuff (torch.cuda.amp)
    ...
else:
    # old amp stuff (apex.amp)
    ...
```
Hmm, I don't know the Lightning codebase at all, aside from the interface. It would take me longer than early next week to be sure I was making the right changes in the right places. The version is a more complex string though, so I'd use something like:

```python
import torch

version_ge_16 = False
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 6):
    version_ge_16 = True
```

Not sure about the particular condition `if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 6):`, but yes, by parsing the version you can set this kind of default flag.

Happy to review in-progress PRs though.

One key point is that torch.cuda.amp.autocast locally (in its context regions) enables Apex "O1"-like casting behavior. It casts inputs on the fly as they enter certain functions; it doesn't need to touch the model or the optimizer at all, nor does it need them to change. You shouldn't manually call .half() on the model or input data. (scaler.step(optimizer) only decides whether or not to call optimizer.step(); it doesn't change the optimizer in any stateful way.)
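A quick sanity check illustrating this (assuming a CUDA build with native amp available): the parameters stay FP32, and only the output of an autocast-eligible op comes back in half precision.

```python
import torch

# Illustrative check: autocast does not modify the model. Parameters remain
# FP32; only autocast-eligible ops inside the region produce FP16 results.
model = torch.nn.Linear(8, 8).cuda()
x = torch.randn(4, 8, device="cuda")

with torch.cuda.amp.autocast():
    out = model(x)

print(next(model.parameters()).dtype)  # torch.float32 -- the model was not touched
print(out.dtype)                       # torch.float16 -- linear ran in half under autocast
```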

Also, that versioning condition is based on what works for us. The particular number (1.6+) is a decent criterion for native amp availability; the window of commits with torch.__version__ = 1.6.xyz that don't yet have autocast and GradScaler is small.

You could sidestep __version__ parsing entirely and check for full native amp support via

```python
has_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
```

Now that I mention it that's probably better.
