I noticed that some users are confused by the variable attention_mask when reading the source code,
like:
What is the meaning of Attention Mask #205
Clarifying attention mask #542
I referred to the original BERT repository, google-research/bert. Compared to the original, this repo sometimes mixes the concepts of attention_mask and adder.
Referring to the original BERT: ./modeling.py#L707
attention_mask = tf.expand_dims(attention_mask, axis=[1])
adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
attention_scores += adder
But in this repo, take src/transformers/modeling_tf_openai.py#L282 as an example:
attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
attention_mask = tf.cast(attention_mask, tf.float32)
attention_mask = (1.0 - attention_mask) * -10000.0
and inside the method TFAttention._attn(), src/transformers/modeling_tf_openai.py#L112:
if attention_mask is not None:
# Apply the attention mask
w = w + attention_mask
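To make the effect of these two snippets concrete, here is a minimal NumPy sketch of the additive mask (the mask and score values are hypothetical toy data, not taken from the repo):

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# hypothetical mask: 1 = real token, 0 = padding
attention_mask = np.array([[1.0, 1.0, 0.0]])

# the "adder" transformation from the snippets above
adder = (1.0 - attention_mask) * -10000.0  # [[0.0, 0.0, -10000.0]]

# hypothetical raw attention scores for one query
w = np.array([[2.0, 1.0, 3.0]])

# adding the adder before softmax drives masked positions to ~0 probability
probs = softmax(w + adder)
```

After the softmax, the padded position gets essentially zero weight, which is why the additive quantity behaves like a mask even though it is a tensor of 0s and large negative numbers rather than 0s and 1s.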
Maybe renaming it would be clearer, like:
attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
attention_mask = tf.cast(attention_mask, tf.float32)
adder = (1.0 - attention_mask) * -10000.0
and then:
if adder is not None:
# Apply the attention mask
w = w + adder
I agree! Do you want to open a PR about this to change the naming? :-)
When doing this we just have to be careful not to change the user-facing API, which means that ideally we should not rename any function arguments of high-level modules like BertModel.forward().
I've created PR #4566 for this issue
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.