Transformers: How to encode a batch of sequences?

Created on 12 Mar 2020 · 1 Comment · Source: huggingface/transformers

Hi, I am trying to learn the transformers package.
I prepared the data in the following format:
(("Mary spends $20 on pizza"), ("She likes eating it"), ("The pizza was great"))
I saw methods like tokenizer.encode, tokenizer.encode_plus and tokenizer.batch_encode_plus. However, tokenizer.encode seems to encode only a single sentence.
When I pass it two sentences as below, this is the result:

>>> d[0][0]
'John was writing lyrics for his new album'
>>> d[0][1]
'Franny did not particularly like all of the immigration happening'
>>> input_ids = torch.tensor(tokenizer.encode([d[0][0],d[0][1]]))
>>> input_ids
tensor([101, 100, 100, 102])

Obviously, this is not the right encoding.
When I try tokenizer.encode_plus, it does not work properly either, even though the documentation says:

"text (str or List[str]) – The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the tokenize method) or a list of integers (tokenized string ids using the convert_tokens_to_ids method)"

It fails even when I pass only a single sentence:

>>> input_ids = torch.tensor(tokenizer.encode_plus(d[0][0]))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Could not infer dtype of dict

And tokenizer.batch_encode_plus gives the same error message.
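
A likely explanation for the [101, 100, 100, 102] result, consistent with the documentation quoted above: a list of strings passed to encode is treated as an already tokenized sequence, so each full sentence is looked up as a single token, is not found in the vocabulary, and maps to the unknown-token id. In the bert-base-uncased vocabulary, 100 is [UNK], 101 is [CLS] and 102 is [SEP]. The sketch below mimics that lookup with a toy vocabulary. The names toy_vocab and convert_tokens_to_ids_toy are hypothetical stand-ins for illustration, not the library's code, and the ids for ordinary words are made up.

```python
from typing import Dict, List

# Toy vocabulary (assumption): only the ids of [UNK], [CLS] and [SEP]
# mirror bert-base-uncased; the word ids are invented for this sketch.
toy_vocab: Dict[str, int] = {
    "[UNK]": 100,
    "[CLS]": 101,
    "[SEP]": 102,
    "john": 1,
    "was": 2,
}


def convert_tokens_to_ids_toy(tokens: List[str], vocab: Dict[str, int]) -> List[int]:
    """Look up each token as-is, falling back to [UNK], then add [CLS]/[SEP]."""
    unk_id = vocab["[UNK]"]
    ids = [vocab.get(token, unk_id) for token in tokens]
    return [vocab["[CLS]"]] + ids + [vocab["[SEP]"]]


sentences = [
    "John was writing lyrics for his new album",
    "Franny did not particularly like all of the immigration happening",
]

# Each whole sentence is treated as ONE token, so both become [UNK].
whole_sentence_ids = convert_tokens_to_ids_toy(sentences, toy_vocab)
print(whole_sentence_ids)  # [101, 100, 100, 102]
```

The RuntimeError from torch.tensor(tokenizer.encode_plus(...)) has a separate cause: encode_plus returns a dict, as the error message says, so the individual entries (for example the one under "input_ids") have to be converted rather than the dict as a whole.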

Most helpful comment

batch_encode_plus is the correct method :-)

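from transformers import BertTokenizer

# A tuple of three sentences; the inner parentheses do not create nested tuples.
batch_input_str = (("Mary spends $20 on pizza"), ("She likes eating it"), ("The pizza was great"))
tok = BertTokenizer.from_pretrained('bert-base-uncased')
# The result is a dict with keys such as input_ids, so print it or index it by key.
# Newer transformers releases deprecate pad_to_max_length in favor of padding=True.
print(tok.batch_encode_plus(batch_input_str, pad_to_max_length=True))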
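
To make the padded output less opaque, here is a minimal pure-Python sketch of what padding a batch involves. It is not the library's implementation: the pad_batch helper and the token ids are made up for illustration, and only the general idea carries over (shorter sequences are filled with a pad id, and an attention mask marks which positions hold real tokens).

```python
from typing import Dict, List


def pad_batch(sequences: List[List[int]], pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Pad every sequence to the longest length and build a matching attention mask."""
    max_len = max(len(seq) for seq in sequences)
    input_ids: List[List[int]] = []
    attention_mask: List[List[int]] = []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding.
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Made-up token ids; only 101 and 102 mirror BERT's [CLS] and [SEP].
batch = [[101, 7, 8, 9, 102], [101, 7, 102]]
encoded = pad_batch(batch)
print(encoded["input_ids"])       # [[101, 7, 8, 9, 102], [101, 7, 102, 0, 0]]
print(encoded["attention_mask"])  # [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]
```

Because every row now has the same length, each entry of the dict can be turned into a rectangular tensor on its own, for example with torch.tensor(encoded["input_ids"]).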
