Hi there. I am attempting to extract the word embeddings that go into the encoder, similar to what's shown here. For that purpose I loaded my finetuned BART model and extracted the encoder's embed_tokens weights with the command below:
>>> bart.state_dict()['model.encoder.embed_tokens.weight'].shape
torch.Size([50264, 1024])
I assume 50264 refers to the number of tokens in my dictionary, and that 1024 is the max sequence length.
How can I map the 50264 vectors back to their corresponding tokens? Ideally, I'd like to end up with {word_1: [val_1, val_2, ..., val_1024], ..., word_50264: [val_1, val_2, ..., val_1024]}, where word_1 is, for example, 'soup'.
Thanks!
Yes, 50264 is the number of tokens in the dictionary and 1024 is the embedding dimension.
For the main vocabulary (i.e., everything after the first 4 special symbols), you can use bart.decode to map them to the raw byte-pair encoded symbols:
import torch

# map the first few regular token ids to their 1024-dim embedding vectors
embed = {
    bart.decode(torch.tensor([i])):
        bart.state_dict()['model.encoder.embed_tokens.weight'][i]
    for i in range(4, 10)
}
embed.keys() # dict_keys(['.', ' the', ',', ' to', ' and', ' of'])
Note that these are not always "words" in the traditional sense but byte-pair encoded (BPE) symbols. Because we use a byte-level BPE, they may not even be full unicode characters. For example:
print(bart.decode(torch.tensor([17])))
# �
Also, the first four symbols are beginning-of-sentence, pad, end-of-sentence and unknown. You can access them via bart.task.source_dictionary if you need:
print([bart.task.source_dictionary[i] for i in range(4)])
# ['<s>', '<pad>', '</s>', '<unk>']
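If you want the full {token: vector} dict from the original question, a rough sketch along these lines should work (untested; note that several byte-level symbols decode to the same replacement character, so dict keys can collide, and a few ids at the very end of the dictionary may not round-trip through the BPE decoder):
import torch

weights = bart.state_dict()['model.encoder.embed_tokens.weight']  # (50264, 1024)
src_dict = bart.task.source_dictionary

embed = {}
for i in range(weights.shape[0]):
    if i < 4:
        # special symbols: <s>, <pad>, </s>, <unk>
        key = src_dict[i]
    else:
        try:
            key = bart.decode(torch.tensor([i]))
        except Exception:
            # some ids at the end of the dictionary (e.g. <mask> or padding
            # symbols) may not decode cleanly; fall back to the raw symbol
            key = src_dict[i]
    embed[key] = weights[i]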
Thanks @myleott. I'm interested in visualising the embedding array, similar to what has been done here with word2vec: https://projector.tensorflow.org/.
Am I right that I should theoretically be able to take the array X below, standardize it, and visualise it in 2D space using PCA, for example? Or is there something I'm missing?
>>> X = np.array(bart.state_dict()['model.encoder.embed_tokens.weight'])
>>> X.shape
(50264, 1024)
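In code, what I have in mind is roughly the following (scikit-learn and matplotlib are just what I'd reach for here, not anything BART-specific):
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

X = bart.state_dict()['model.encoder.embed_tokens.weight'].cpu().numpy()  # (50264, 1024)

# standardize each of the 1024 dimensions, then project onto the first two
# principal components for a 2D scatter plot
X_std = StandardScaler().fit_transform(X)
X_2d = PCA(n_components=2).fit_transform(X_std)

plt.scatter(X_2d[:, 0], X_2d[:, 1], s=1)
plt.show()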