Hi,
I tried fp16 + XLNet and it did not work.
When I set opt_level='O2', memory usage was halved, but training was much slower than fp32.
When I set opt_level='O1', memory usage was unchanged and the speed was similar to fp32.
Environment: V100, CUDA 10.1, PyTorch 1.1
The environment itself is fine: BERT + fp16 is much faster than fp32 on the same setup.
I suspect the problem is torch.einsum, but I am not sure.
I used this script to test: https://github.com/huggingface/pytorch-transformers/blob/master/examples/run_squad.py
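For context, the opt_level above is the value passed to Apex AMP at initialization; a minimal sketch of that setup (not the exact run_squad.py code, which does the equivalent behind its --fp16 options) looks like this:

```python
import torch
from apex import amp  # NVIDIA Apex, assumed installed

# Minimal sketch of the AMP setup that opt_level controls.
model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# "O2" casts the model to fp16 with fp32 master weights;
# "O1" instead patches torch functions for mixed precision.
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
```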
XLNet makes heavy use of torch.einsum(), and I'm not sure this method is fp16-compatible.
It's also quite slow at the moment, so in the mid/long term it would be good to replace these einsum calls with standard matmul (see the sketch below). I won't have time to do that very soon, though.
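For illustration only (this is not the exact contraction used in the XLNet code), the kind of rewrite meant here is replacing an einsum contraction with an equivalent transpose + matmul:

```python
import torch

# Illustrative attention-score-like contraction, not the exact one in modeling_xlnet.py.
q = torch.randn(8, 128, 64)  # (batch, seq, head_dim)
k = torch.randn(8, 128, 64)

# einsum form
scores_einsum = torch.einsum("bid,bjd->bij", q, k)

# equivalent matmul form, which dispatches straight to a batched GEMM
scores_matmul = torch.matmul(q, k.transpose(1, 2))

assert torch.allclose(scores_einsum, scores_matmul, atol=1e-4)
```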
As a suggestion, you can add apex.amp.register_half_function(torch, 'einsum') somewhere near the top of your driver script (examples/run_squad.py, for instance). This forces amp to cast the inputs of einsum to torch.half before executing, letting you get the perf benefits of fp16 + Tensor Cores when appropriate.
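A minimal sketch of where that call would go, assuming the registration happens before amp.initialize() as Apex expects:

```python
import torch
from apex import amp  # NVIDIA Apex, assumed installed

# Register torch.einsum so amp casts its inputs to fp16 on CUDA.
# This needs to happen before amp.initialize() is called.
amp.register_half_function(torch, "einsum")

# ... the rest of the driver script's usual fp16 setup:
model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
```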