Transformers: Remove hard-coded uses of float32 to fix mixed precision use in Distilbert

Created on 31 Aug 2020 · 4 comments · Source: huggingface/transformers

In the commit "Remove hard-coded uses of float32 to fix mixed precision use", the mixed precision issue was fixed for modeling_tf_bert.py.

However, in modeling_tf_distilbert.py, line 171 has not been fixed yet, and we get

173 embeddings = inputs_embeds + position_embeddings # (bs, max_seq_length, dim)
--> 174 embeddings = self.LayerNorm(embeddings) # (bs, max_seq_length, dim)
InvalidArgumentError: cannot compute AddV2 as input #1(zero-based) was expected to be a bfloat16 tensor but is a float tensor

while using the mixed_bfloat16 mixed precision policy on a TPU.
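
For context, here is a minimal sketch of the kind of setup that triggers the error; the policy call follows the TF 2.x Keras mixed precision API, and the exact training script is not part of this issue:

import tensorflow as tf
from transformers import TFDistilBertModel

# Enable bfloat16 mixed precision, as typically done for TPU training.
# (On TF >= 2.4 the equivalent call is tf.keras.mixed_precision.set_global_policy.)
tf.keras.mixed_precision.experimental.set_policy("mixed_bfloat16")

model = TFDistilBertModel.from_pretrained("distilbert-base-uncased")
input_ids = tf.constant([[101, 7592, 2088, 102]])  # toy input ids
# With the policy active, the word embeddings come out as bfloat16 while the
# hard-coded float32 position embeddings do not, hitting the AddV2 dtype error above.
outputs = model(input_ids)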

A quick fix is the same as the one applied to modeling_tf_bert.py:

position_embeddings = tf.cast(self.position_embeddings(position_ids), inputs_embeds.dtype)
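
For reference, this is roughly how the cast would sit in the surrounding embedding code; the names follow modeling_tf_distilbert.py and the traceback above, but treat this as an illustrative sketch of the proposed change, not a verified diff:

# Inside the DistilBERT embeddings layer (sketch):
position_embeddings = tf.cast(
    self.position_embeddings(position_ids), inputs_embeds.dtype
)  # match the (possibly bfloat16) dtype of the word embeddings
embeddings = inputs_embeds + position_embeddings  # (bs, max_seq_length, dim)
embeddings = self.LayerNorm(embeddings)           # (bs, max_seq_length, dim)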

@schmidek

All 4 comments

Indeed! Do you want to open a PR to fix this?

@LysandreJik

I can do that. However, @patrickvonplaten has already self-assigned this issue. What do you think, @patrickvonplaten?

Hey @chiapas, it would be great if you could open a PR for it :-)

Hi @patrickvonplaten, OK, that would be my first contribution to transformers :)
