I would like to use BART in FP16 mode, but it seems impossible for now:
from transformers import AutoModelWithLMHead, AutoTokenizer, BartConfig

config = BartConfig(vocab_size=50264, output_past=True)
model = AutoModelWithLMHead.from_pretrained('bart-large-cnn', config=config).cuda().half()  # cast weights to fp16
tokenizer = AutoTokenizer.from_pretrained('bart-large-cnn')
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
generated_ids = model.generate(inputs['input_ids'].cuda(), attention_mask=inputs['attention_mask'].cuda(), num_beams=4, max_length=5)  # fails here
File "/data/user/.venv/bartqg/lib/python3.6/site-packages/transformers/modeling_bart.py", line 647, in forward
attn_output = torch.bmm(attn_probs, v)
RuntimeError: Expected object of scalar type Float but got scalar type Half for argument #2 'mat2' in call to _th_bmm
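The mismatch is exactly what the traceback says: attn_probs is still float32 while v has been cast to float16 by .half(), so torch.bmm refuses to mix the two. As a stopgap (not a fix inside the model), one option that may sidestep the cast is to keep the weights in fp32 and let automatic mixed precision pick the dtypes during generation. This is only a sketch, assuming PyTorch >= 1.6 with torch.cuda.amp available:

import torch
from transformers import AutoModelWithLMHead, AutoTokenizer, BartConfig

# Possible workaround sketch (assumes PyTorch >= 1.6): keep the weights in fp32
# and let autocast run eligible ops such as bmm in fp16 during generation.
config = BartConfig(vocab_size=50264, output_past=True)
model = AutoModelWithLMHead.from_pretrained('bart-large-cnn', config=config).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained('bart-large-cnn')

inputs = tokenizer.batch_encode_plus(
    ["My friends are cool but they eat too many carbs."],
    max_length=1024, return_tensors='pt',
)

with torch.no_grad(), torch.cuda.amp.autocast():
    generated_ids = model.generate(
        inputs['input_ids'].cuda(),
        attention_mask=inputs['attention_mask'].cuda(),
        num_beams=4, max_length=5,
    )

Note that this keeps the parameters in fp32, so it only gives half-precision compute, not the memory savings of a true fp16 model.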
@sshleifer Do you plan to implement an FP16-friendly version of BART?
Not on my roadmap just yet, but I would definitely consider it if there were lots of demand. Since we only have inference code right now, the benefit seems marginal.
@BramVanroy Should this issue be closed?
FP16 is not implemented yet. And the wontfix label is clear.
Keeping the issue open may make it easier for people to find it and show their potential interest in FP16.
This should not be closed indeed.
@sshleifer, we intend all the models to be compatible with FP16; this is the direction the field is going, and with Volta-level GPUs now widespread, there is less and less reason not to use mixed-precision fine-tuning (half the memory and significantly faster).
Yep, on it!
Hi @sshleifer, thank you so much for your effort on BART. I ran into the same fp16 issue today. The current BART code can be trained (without fp16) using the run_glue script at https://github.com/huggingface/transformers/blob/master/examples/run_glue.py, so it would be really nice if fp16 training worked as well.
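For reference, here is a minimal sketch of what a mixed-precision training step could look like once the model handles fp16, using torch.cuda.amp from PyTorch >= 1.6 and independent of run_glue.py; train_dataloader, the optimizer settings, and the labels/loss handling below are placeholders for illustration only.

import torch
from transformers import AutoModelWithLMHead

# Illustrative sketch of a mixed-precision training step with torch.cuda.amp
# (PyTorch >= 1.6). `train_dataloader` is a hypothetical DataLoader, and the
# assumption that the model returns the LM loss first when `labels` is passed
# is mine, not taken from run_glue.py.
model = AutoModelWithLMHead.from_pretrained('bart-large-cnn').cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
scaler = torch.cuda.amp.GradScaler()

for batch in train_dataloader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():          # forward pass in mixed precision
        outputs = model(
            input_ids=batch['input_ids'].cuda(),
            attention_mask=batch['attention_mask'].cuda(),
            labels=batch['labels'].cuda(),
        )
        loss = outputs[0]
    scaler.scale(loss).backward()            # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)                   # unscales gradients, then steps the optimizer
    scaler.update()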
My bad, I thought @sshleifer's wontfix label was a note that he wasn't planning to change anything, so no future updates would be coming, and that's why I closed it. Will keep that in mind for the future.
No worries.
@sshleifer for the moment, please ping me with a DM before adding "wontfix" labels to issues, thanks.