In the commit "Remove hard-coded uses of float32 to fix mixed precision use", the mixed precision issue was fixed for modeling_tf_bert.py.
However, for modeling_tf_distilbert.py, line 171 is not fixed yet, and we get
173 embeddings = inputs_embeds + position_embeddings # (bs, max_seq_length, dim)
--> 174 embeddings = self.LayerNorm(embeddings) # (bs, max_seq_length, dim)
InvalidArgumentError: cannot compute AddV2 as input #1(zero-based) was expected to be a bfloat16 tensor but is a float tensor
when using mixed_bfloat16 mixed precision on TPU.
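For context, the mismatch can be reproduced outside the library with a toy example (illustrative names only, not the actual DistilBERT code; assumes the TF 2.4+ mixed precision API):

# Toy sketch of the dtype mismatch under a mixed_bfloat16 policy.
# The names word_embeddings / position_table are illustrative, not from the library.
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")

word_embeddings = tf.keras.layers.Embedding(100, 16)        # Keras layer: compute dtype is bfloat16
position_table = tf.Variable(tf.zeros([8, 16]))             # plain variable: stays float32

inputs_embeds = word_embeddings(tf.constant([[1, 2, 3]]))   # bfloat16 under the policy
position_embeddings = tf.gather(position_table, [0, 1, 2])  # float32

# The add below fails with the same error as above:
#   InvalidArgumentError: cannot compute AddV2 as input #1(zero-based) was
#   expected to be a bfloat16 tensor but is a float tensor
# embeddings = inputs_embeds + position_embeddings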
A very quick fix is the same as the one for modeling_tf_bert.py:
position_embeddings = tf.cast(self.position_embeddings(position_ids), inputs_embeds.dtype)
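Continuing the toy sketch above, casting the float32 positions to the dtype of inputs_embeds is all the fix does:

# Cast the float32 position embeddings to the compute dtype before the add.
position_embeddings = tf.cast(position_embeddings, inputs_embeds.dtype)
embeddings = inputs_embeds + position_embeddings   # both bfloat16 now, AddV2 succeeds
print(embeddings.dtype)                            # <dtype: 'bfloat16'>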
@schmidek
Indeed! Do you want to open a PR to fix this?
@LysandreJik
I can do that. However, @patrickvonplaten has already self-assigned this issue. What do you think, @patrickvonplaten?
Hey @chiapas, it would be great if you could open a PR for it :-)
Hi @patrickvonplaten , OK, that would be my first contribution to transformers :)