Hi,
This is probably me doing something wrong, but I can't get DistilBERT to give me a sensible prediction when I mask part of a sentence.
The following setup for BERT (based on the examples) works:
import logging
import torch
from transformers import BertTokenizer, BertForMaskedLM

logging.basicConfig(level=logging.INFO)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
text = "Hello how are you doing?"
tokenized_text = tokenizer.tokenize(text)
masked_index = 2  # mask the token "are"
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0]  # one segment id per token
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()

with torch.no_grad():
    outputs = model(tokens_tensor, token_type_ids=segments_tensors)
    predictions = outputs[0]

predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
print(predicted_token)
This gives the correct answer _are_ for the masked position in _Hello how are you doing?_.
But when I try the same with distilbert:
import torch
from transformers import DistilBertTokenizer, DistilBertModel

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
text = "Hello how are you doing?"
tokenized_text = tokenizer.tokenize(text)
masked_index = 2  # mask the token "are"
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
tokens_tensor = torch.tensor([indexed_tokens])

model = DistilBertModel.from_pretrained('distilbert-base-uncased')
model.eval()

with torch.no_grad():
    # no segment tensors here: passing them to DistilBERT throws an error
    last_hidden_states = model(tokens_tensor)
    outputs = last_hidden_states[0]

predicted_index = torch.argmax(outputs[0], dim=1)[masked_index].item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
print(predicted_token)
I practically always get some _[unusedXXX]_ token as a result. At first I thought this was because DistilBERT is a smaller model, but no matter what I try I keep getting _[unusedXXX]_, so I am guessing it's something else.
Thanks in advance!
In the full BERT case you are using BertForMaskedLM, but for DistilBERT you are using DistilBertModel, which is not for masked language modelling: it only returns hidden states, so your argmax picks an index into the hidden dimension rather than a vocabulary id. Try using DistilBertForMaskedLM. Check it, it works:
https://colab.research.google.com/drive/1GYt9H9QRUa5clFfAke6KPYl0mi4H1F3H
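Roughly, the fix looks like this; a minimal sketch assuming the same tuple-style outputs as the snippets above (the linked Colab has a working version):

import torch
from transformers import DistilBertTokenizer, DistilBertForMaskedLM

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')
model.eval()

text = "Hello how are you doing?"
tokenized_text = tokenizer.tokenize(text)
masked_index = 2  # mask the token "are"
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
tokens_tensor = torch.tensor([indexed_tokens])

with torch.no_grad():
    # DistilBERT takes no token_type_ids, so no segments tensor is passed
    outputs = model(tokens_tensor)
    predictions = outputs[0]  # vocabulary logits, shape [1, seq_len, vocab_size]

predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
print(predicted_token)  # expected: "are"

The key difference from the DistilBertModel version is that the MLM head projects the hidden states onto the vocabulary, so the argmax over the masked position is a token id rather than a hidden-dimension index.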
Well, in hindsight that was obvious. :) Thanks!