Transformers: BART FP16

Created on 4 Mar 2020 · 8 comments · Source: huggingface/transformers

🚀 Feature request

I would like to use BART in FP16 mode, but it seems impossible for now:

from transformers import AutoModelWithLMHead, AutoTokenizer, BartConfig

config = BartConfig(vocab_size=50264, output_past=True)
model = AutoModelWithLMHead.from_pretrained('bart-large-cnn', config=config).cuda().half()  # .half() puts the weights in FP16
tokenizer = AutoTokenizer.from_pretrained('bart-large-cnn')
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
generated_ids = model.generate(inputs['input_ids'].cuda(), attention_mask=inputs['attention_mask'].cuda(), num_beams=4, max_length=5)

File "/data/user/.venv/bartqg/lib/python3.6/site-packages/transformers/modeling_bart.py", line 647, in forward
attn_output = torch.bmm(attn_probs, v)
RuntimeError: Expected object of scalar type Float but got scalar type Half for argument #2 'mat2' in call to _th_bmm
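
For anyone who wants to poke at the failure outside of generate(), here is a standalone illustration of the mismatch (a sketch, not the actual transformers code): the attention probabilities come out of a float32 softmax while v stays in half precision after .half(), and torch.bmm requires both operands to share a dtype.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Softmax forced to float32 (as modeling_bart.py does), while v is still FP16.
attn_probs = torch.softmax(torch.randn(1, 4, 4, device=device, dtype=torch.float16),
                           dim=-1, dtype=torch.float32)
v = torch.randn(1, 4, 8, device=device, dtype=torch.float16)
try:
    torch.bmm(attn_probs, v)
except RuntimeError as e:
    print(e)  # same Float-vs-Half complaint as in the traceback above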

@sshleifer Do you plan to implement an FP16-friendly version of BART?

All 8 comments

Not on my roadmap just yet, but I would definitely consider it if there were lots of demand. Since we only have inference code right now, the benefit seems marginal.

@BramVanroy Should this issue be closed?

FP16 is not implemented yet, and the wontfix label is clear.

Keeping the issue open may make it easier for people to find it and show their potential interest in FP16.

Indeed, this should not be closed.

@sshleifer, we intend all the models to be compatible with FP16; this is the direction the field is going, and with Volta-level GPUs now widespread there is less and less reason not to use mixed-precision fine-tuning (half the memory and significantly faster).

This can probably be fixed by changing the torch.float32 cast here to a cast to the dtype of attn_weights, as is done in the original fairseq code here.
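
For illustration, here is a minimal sketch of that kind of change, mirroring the fairseq pattern of casting the softmax result back to the incoming dtype (the names attn_weights and v follow the traceback above; this is not the exact modeling_bart.py code):

import torch
import torch.nn.functional as F

def attention_output(attn_weights, v):
    # Softmax in float32 for numerical stability, then cast back so that
    # both torch.bmm operands share a dtype under FP16.
    attn_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
    return torch.bmm(attn_probs, v)

# FP16 on GPU reproduces the original setting; float32 on CPU exercises the same path.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
w = torch.randn(1, 4, 4, device=device, dtype=dtype)
v = torch.randn(1, 4, 8, device=device, dtype=dtype)
print(attention_output(w, v).dtype)  # float16 on GPU, float32 on CPU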

Do you mind fixing this and testing the failing script posted in the issue, @sshleifer?

Yep, on it!

Hi @sshleifer, thank you so much for your effort on BART. I ran into the same FP16 issue today. The current BART code can be trained (without FP16) using the run_glue script at https://github.com/huggingface/transformers/blob/master/examples/run_glue.py, so it would be really nice if FP16 training also worked.
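
For readers who want to experiment once the casting issue is fixed, here is a rough sketch of what a single FP16 fine-tuning step could look like with NVIDIA apex, the mechanism behind the example scripts' --fp16 flag at the time of writing; the toy reconstruction objective, optimizer, and hyperparameters are illustrative assumptions, not the script's actual training loop.

import torch
import torch.nn.functional as F
from apex import amp
from transformers import AutoModelWithLMHead, AutoTokenizer

model = AutoModelWithLMHead.from_pretrained('bart-large-cnn').cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
# opt_level 'O1' enables mixed precision with automatic per-op casting.
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

tokenizer = AutoTokenizer.from_pretrained('bart-large-cnn')
batch = tokenizer.batch_encode_plus(["My friends are cool but they eat too many carbs."],
                                    max_length=1024, return_tensors='pt')
input_ids = batch['input_ids'].cuda()
attention_mask = batch['attention_mask'].cuda()

logits = model(input_ids, attention_mask=attention_mask)[0]
# Toy objective: reconstruct the input tokens (illustrative only).
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), input_ids.reshape(-1))
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()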

My bad, I thought @sshleifer's wontfix label meant that he wasn't planning to change anything, so no future updates would be coming, and that's why I closed it. Will keep that in mind for the future.

No problem.

@sshleifer for the moment, please ping me with a DM before adding "wontfix" labels to issues, thanks.
