Transformers Encoder and Decoder Inference

Created on 17 Dec 2019 · 10 comments · Source: huggingface/transformers

🐛 Bug

Model I am using (Bert, XLNet....):

Language I am using the model on (English, Chinese....):

The problem arises when using:

  • [x] the official example scripts: (give details)
  • [ ] my own modified scripts: (give details)

The task I am working on is:

  • [ ] an official GLUE/SQuAD task: (give the name)
  • [x] my own task or dataset: (give details)

To Reproduce

Steps to reproduce the behavior:

  1. The following code raises an error during inference:

from transformers import PreTrainedEncoderDecoder, BertTokenizer
import torch

model = PreTrainedEncoderDecoder.from_pretrained('bert-base-uncased', 'bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
encoder_input_ids = tokenizer.encode("Hi How are you")
output = model(torch.tensor(encoder_input_ids).unsqueeze(0))

and the error is

TypeError: forward() missing 1 required positional argument: 'decoder_input_ids'

Why is a decoder input expected during inference?

Let me know if I'm missing anything.

Environment

OS: Ubuntu
Python version: 3.6
PyTorch version: 1.3.0
PyTorch Transformers version (or branch): 2.2.0
Using GPU? Yes
Distributed or parallel setup? No
Any other relevant information:

Additional context

All 10 comments

As @rlouf (one of the authors of Transformers) said in #2117, at the moment you can use PreTrainedEncoderDecoder only with BERT, both as the encoder and as the decoder.

In more detail, he said: "_Indeed, as I specified in the article, PreTrainedEncoderDecoder only works with BERT as an encoder and BERT as a decoder. GPT2 shouldn't take too much work to adapt, but we haven't had the time to do it yet. Try PreTrainedEncoderDecoder.from_pretrained('bert-base-uncased', 'bert-base-uncased') should work. Let me know if it doesn't._".
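
For reference, here is a minimal sketch of that suggestion (a hypothetical snippet assuming the transformers 2.x API discussed in this thread), with BERT used as both the encoder and the decoder:

from transformers import PreTrainedEncoderDecoder, BertTokenizer

# Both halves of the encoder-decoder are initialized from the same BERT checkpoint.
model = PreTrainedEncoderDecoder.from_pretrained('bert-base-uncased', 'bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')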

Bug

Model I am using (Bert, XLNet....):

Language I am using the model on (English, Chinese....):

The problem arises when using:

  • [x] the official example scripts: (give details)
  • [ ] my own modified scripts: (give details)

The task I am working on is:

  • [ ] an official GLUE/SQuAD task: (give the name)
  • [x] my own task or dataset: (give details)

To Reproduce

Steps to reproduce the behavior:

  1. The following code raises an error during inference:

from transformers import PreTrainedEncoderDecoder, BertTokenizer
import torch

model = PreTrainedEncoderDecoder.from_pretrained('bert-base-uncased', 'gpt2')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
encoder_input_ids = tokenizer.encode("Hi How are you")
output = model(torch.tensor(encoder_input_ids).unsqueeze(0))

and the error is

TypeError: forward() missing 1 required positional argument: 'decoder_input_ids'

Why is a decoder input expected during inference?

Let me know if I'm missing anything.

Environment

OS: Ubuntu
Python version: 3.6
PyTorch version: 1.3.0
PyTorch Transformers version (or branch): 2.2.0
Using GPU? Yes
Distributed or parallel setup? No
Any other relevant information:

Additional context

@TheEdoardo93 it doesn't matter whether it is GPT2 or BERT; both give the same error.
I'm trying to play with GPT2, that's why I pasted my own code.

Using BERT as an Encoder and Decoder

>>> model = PreTrainedEncoderDecoder.from_pretrained('bert-base-uncased', 'bert-base-uncased')
>>> output = model(torch.tensor(encoder_input_ids).unsqueeze(0))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/guest_1/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'decoder_input_ids'

First of all, the authors of Transformers are still working on the implementation of the PreTrainedEncoderDecoder object, so it's not a definitive implementation; e.g., some methods have not been implemented yet. That said, I've tested your code and worked out how to use PreTrainedEncoderDecoder correctly, without bugs. You can see my code below.

In brief, the problem occurs because you have not passed _all_ the arguments that the forward method requires. By looking at the source code here, you can see that this method expects two positional arguments: encoder_input_ids and decoder_input_ids. In your code you passed _only one_, so the Python interpreter binds it to encoder_input_ids and no value is supplied for decoder_input_ids; that is what raises the error.
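
In other words, a minimal sketch of the fix, assuming the forward signature described above, is simply to pass a tensor for decoder_input_ids as well; the full interactive session below shows the same call:

import torch
from transformers import PreTrainedEncoderDecoder, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = PreTrainedEncoderDecoder.from_pretrained('bert-base-uncased', 'bert-base-uncased')

input_ids = torch.tensor(tokenizer.encode("Hi How are you")).unsqueeze(0)

# model(input_ids) fails with the TypeError above because decoder_input_ids is missing;
# passing a second tensor (here simply the same ids, purely for illustration) satisfies the signature.
output = model(input_ids, input_ids)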

Python 3.6.9 |Anaconda, Inc.| (default, Jul 30 2019, 19:07:31) 
[GCC 7.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import transformers
>>> from transformers import PreTrainedEncoderDecoder
>>> from transformers import BertTokenizer
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
>>> model = PreTrainedEncoderDecoder.from_pretrained('bert-base-uncased', 'bert-base-uncased')
>>> text='Hi How are you'
>>> import torch
>>> input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)
>>> input_ids
tensor([[ 101, 7632, 2129, 2024, 2017,  102]])
>>> output = model(input_ids) # YOUR PROBLEM IS HERE
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/<user>/anaconda3/envs/huggingface/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'decoder_input_ids'
>>> output = model(input_ids, input_ids) # SOLUTION TO YOUR PROBLEM
>>> output
(tensor([[[ -6.3390,  -6.3664,  -6.4600,  ...,  -5.5354,  -4.1787,  -5.8384],
         [ -6.3550,  -6.3077,  -6.4661,  ...,  -5.3516,  -4.1338,  -4.0742],
         [ -6.7090,  -6.6050,  -6.6682,  ...,  -5.9591,  -4.7142,  -3.8219],
         [ -7.7608,  -7.5956,  -7.6634,  ...,  -6.8113,  -5.7777,  -4.1638],
         [ -8.6462,  -8.5767,  -8.6366,  ...,  -7.9503,  -6.5382,  -5.0959],
         [-12.8752, -12.3775, -12.2770,  ..., -10.0880, -10.7659,  -9.0092]]],
       grad_fn=<AddBackward0>), tensor([[[ 0.0929, -0.0264, -0.1224,  ..., -0.2106,  0.1739,  0.1725],
         [ 0.4074, -0.0593,  0.5523,  ..., -0.6791,  0.6556, -0.2946],
         [-0.2116, -0.6859, -0.4628,  ...,  0.1528,  0.5977, -0.9102],
         [ 0.3992, -1.3208, -0.0801,  ..., -0.3213,  0.2557, -0.5780],
         [-0.0757, -1.3394,  0.1816,  ...,  0.0746,  0.4032, -0.7080],
         [ 0.5989, -0.2841, -0.3490,  ...,  0.3042, -0.4368, -0.2097]]],
       grad_fn=<NativeLayerNormBackward>), tensor([[-9.3097e-01, -3.3807e-01, -6.2162e-01,  8.4082e-01,  4.4154e-01,
         -1.5889e-01,  9.3273e-01,  2.2240e-01, -4.3249e-01, -9.9998e-01,
         -2.7810e-01,  8.9449e-01,  9.8638e-01,  6.4763e-02,  9.6649e-01,
         -7.7835e-01, -4.4046e-01, -5.9515e-01,  2.7585e-01, -7.4638e-01,
          7.4700e-01,  9.9983e-01,  4.4468e-01,  2.8673e-01,  3.6586e-01,
          9.7642e-01, -8.4343e-01,  9.6599e-01,  9.7235e-01,  7.2667e-01,
         -7.5785e-01,  9.2892e-02, -9.9089e-01, -1.7004e-01, -6.8200e-01,
         -9.9283e-01,  2.6244e-01, -7.9871e-01,  2.3397e-02,  4.6413e-02,
         -9.3371e-01,  2.7699e-01,  9.9995e-01, -3.2671e-01,  2.1108e-01,
         -2.0636e-01, -1.0000e+00,  1.9622e-01, -9.3330e-01,  6.8736e-01,
          6.4731e-01,  5.3773e-01,  9.2759e-02,  4.1069e-01,  4.0360e-01,
          1.9002e-01, -1.7049e-01,  7.5259e-03, -2.0453e-01, -5.7574e-01,
         -5.3062e-01,  3.9367e-01, -7.0627e-01, -9.2865e-01,  6.8820e-01,
          3.2698e-01, -3.3506e-02, -1.2323e-01, -1.5304e-01, -1.8077e-01,
          9.3398e-01,  2.6375e-01,  3.7505e-01, -8.9548e-01,  1.1777e-01,
          2.2054e-01, -6.3351e-01,  1.0000e+00, -6.9228e-01, -9.8653e-01,
          6.9799e-01,  4.0303e-01,  5.2453e-01,  2.3217e-01, -1.2151e-01,
         -1.0000e+00,  5.6760e-01,  2.9295e-02, -9.9318e-01,  8.3171e-02,
          5.2939e-01, -2.3176e-01, -1.5694e-01,  4.9278e-01, -4.2614e-01,
         -3.8079e-01, -2.6060e-01, -6.9055e-01, -1.7180e-01, -1.9810e-01,
         -2.7986e-02, -7.2085e-02, -3.7635e-01, -3.7743e-01,  1.3508e-01,
         -4.3892e-01, -6.1321e-01,  1.7726e-01, -3.5434e-01,  6.4734e-01,
          4.0373e-01, -2.8194e-01,  4.5104e-01, -9.7876e-01,  6.1044e-01,
         -2.3526e-01, -9.9035e-01, -5.1350e-01, -9.9280e-01,  6.8329e-01,
         -2.1623e-01, -1.4641e-01,  9.8273e-01,  3.7345e-01,  4.8171e-01,
         -5.6467e-03, -7.3005e-01, -1.0000e+00, -7.2252e-01, -5.1978e-01,
          7.0765e-02, -1.5036e-01, -9.8355e-01, -9.7384e-01,  5.8453e-01,
          9.6710e-01,  1.4193e-01,  9.9981e-01, -2.1194e-01,  9.6675e-01,
          2.3627e-02, -4.1555e-01,  1.9872e-01, -4.0593e-01,  6.5180e-01,
          6.1598e-01, -6.8750e-01,  7.9808e-02, -2.0437e-01,  3.4504e-01,
         -6.7176e-01, -1.3692e-01, -2.7750e-01, -9.6740e-01, -3.6698e-01,
          9.6934e-01, -2.5050e-01, -6.9297e-01,  4.8327e-01, -1.4613e-01,
         -5.1224e-01,  8.8387e-01,  6.9173e-01,  3.8395e-01, -1.7536e-01,
          3.8873e-01, -4.3011e-03,  6.1876e-01, -8.9292e-01,  3.4243e-02,
          4.5193e-01, -2.4782e-01, -4.7402e-01, -9.8375e-01, -3.1763e-01,
          5.9109e-01,  9.9284e-01,  7.9634e-01,  2.4601e-01,  6.1729e-01,
         -1.8376e-01,  6.8750e-01, -9.7083e-01,  9.8624e-01, -2.0573e-01,
          2.0418e-01,  4.1400e-01,  1.9102e-01, -9.1718e-01, -3.5273e-01,
          8.9628e-01, -5.6812e-01, -8.9552e-01, -3.5567e-02, -4.9052e-01,
         -4.3559e-01, -6.2323e-01,  5.6863e-01, -2.6201e-01, -3.1324e-01,
         -1.2852e-02,  9.4585e-01,  9.8664e-01,  8.3363e-01, -2.4392e-01,
          7.3786e-01, -9.4466e-01, -5.2720e-01, -1.6349e-02,  2.4207e-01,
          3.6905e-02,  9.9638e-01, -5.8095e-01, -7.2046e-02, -9.4418e-01,
         -9.8921e-01, -1.0289e-01, -9.3301e-01, -5.3531e-02, -6.8719e-01,
          5.3295e-01,  1.6390e-01,  2.3460e-01,  4.3260e-01, -9.9501e-01,
         -7.7318e-01,  2.6342e-01, -3.6949e-01,  4.0245e-01, -1.6657e-01,
          4.5766e-01,  7.4537e-01, -5.8549e-01,  8.4632e-01,  9.3526e-01,
         -6.4963e-01, -7.8264e-01,  8.5868e-01, -2.9683e-01,  9.0246e-01,
         -6.5124e-01,  9.8896e-01,  8.6732e-01,  8.7014e-01, -9.5627e-01,
         -4.1195e-01, -9.1043e-01, -4.5438e-01,  5.7729e-02, -3.6862e-01,
          5.7032e-01,  5.5757e-01,  3.0482e-01,  7.0850e-01, -6.6279e-01,
          9.9909e-01, -5.0139e-01, -9.7001e-01, -2.2370e-01, -1.8440e-02,
         -9.9107e-01,  7.2208e-01,  2.4379e-01,  6.9083e-02, -3.2313e-01,
         -7.3217e-01, -9.7295e-01,  9.2268e-01,  5.0675e-02,  9.9215e-01,
         -8.0247e-02, -9.5682e-01, -4.1637e-01, -9.4549e-01, -2.9790e-01,
         -1.5625e-01,  2.4707e-01, -1.8468e-01, -9.7276e-01,  4.7428e-01,
          5.6760e-01,  5.5919e-01, -1.9418e-01,  9.9932e-01,  1.0000e+00,
          9.7844e-01,  9.3669e-01,  9.5284e-01, -9.9929e-01, -4.9083e-01,
          9.9999e-01, -9.7835e-01, -1.0000e+00, -9.5292e-01, -6.5736e-01,
          4.1425e-01, -1.0000e+00, -2.7896e-03,  7.0756e-02, -9.4186e-01,
          2.7960e-01,  9.8389e-01,  9.9658e-01, -1.0000e+00,  8.8289e-01,
          9.6828e-01, -6.2958e-01,  9.3367e-01, -3.7519e-01,  9.8027e-01,
          4.2505e-01,  3.0766e-01, -3.2042e-01,  2.7469e-01, -7.8253e-01,
         -8.8309e-01, -1.5604e-01, -3.6222e-01,  9.9091e-01,  3.0116e-02,
         -7.8697e-01, -9.4496e-01,  2.2050e-01, -8.4521e-02, -4.8378e-01,
         -9.7952e-01, -1.3446e-01,  4.2209e-01,  7.8760e-01,  6.4992e-02,
          2.0492e-01, -7.8143e-01,  2.2120e-01, -5.0228e-01,  3.7149e-01,
          6.5244e-01, -9.4897e-01, -6.0978e-01, -4.8976e-03, -4.6856e-01,
         -2.8122e-01, -9.6984e-01,  9.8036e-01, -3.5220e-01,  7.4903e-01,
          1.0000e+00, -1.0373e-01, -9.4037e-01,  6.2856e-01,  1.5745e-01,
         -1.1596e-01,  1.0000e+00,  7.2891e-01, -9.8543e-01, -5.3814e-01,
          4.0543e-01, -4.9501e-01, -4.8527e-01,  9.9950e-01, -1.4058e-01,
         -2.3799e-01, -1.5841e-01,  9.8467e-01, -9.9180e-01,  9.7240e-01,
         -9.5292e-01, -9.8022e-01,  9.8012e-01,  9.5211e-01, -6.6387e-01,
         -7.2622e-01,  1.1509e-01, -3.1365e-01,  1.8487e-01, -9.7602e-01,
          7.9482e-01,  5.2428e-01, -1.3540e-01,  9.1377e-01, -9.0275e-01,
         -5.2769e-01,  2.8301e-01, -4.9215e-01,  3.4866e-02,  8.3573e-01,
          5.0270e-01, -2.4031e-01, -6.1194e-02, -2.5558e-01, -1.3530e-01,
         -9.8688e-01,  2.9877e-01,  1.0000e+00, -5.3199e-02,  4.4522e-01,
         -2.4564e-01,  3.8897e-02, -3.7170e-01,  3.6843e-01,  5.1087e-01,
         -1.9742e-01, -8.8481e-01,  4.5420e-01, -9.8222e-01, -9.8894e-01,
          8.5417e-01,  1.4674e-01, -3.3154e-01,  9.9999e-01,  4.4333e-01,
          7.1728e-02,  1.6790e-01,  9.6064e-01,  2.3267e-02,  7.2436e-01,
          4.9905e-01,  9.8528e-01, -2.0286e-01,  5.2711e-01,  9.0711e-01,
         -5.6147e-01, -3.4452e-01, -6.1113e-01, -8.1268e-02, -9.2887e-01,
          1.0119e-01, -9.7066e-01,  9.7404e-01,  8.2025e-01,  2.9760e-01,
          1.9059e-01,  3.6089e-01,  1.0000e+00, -2.7256e-01,  6.5052e-01,
         -6.0092e-01,  9.0897e-01, -9.9819e-01, -9.1409e-01, -3.7810e-01,
          1.2677e-02, -3.9492e-01, -3.0028e-01,  3.4323e-01, -9.7925e-01,
          4.4501e-01,  3.7582e-01, -9.9622e-01, -9.9495e-01,  1.6366e-01,
          9.2522e-01, -1.3063e-02, -9.5314e-01, -7.5003e-01, -6.5409e-01,
          4.1526e-01, -7.6235e-02, -9.6046e-01,  3.2395e-01, -2.7184e-01,
          4.7535e-01, -1.1767e-01,  5.6867e-01,  4.6844e-01,  8.3125e-01,
         -2.1505e-01, -2.6495e-01, -4.4479e-02, -8.5166e-01,  8.8927e-01,
         -8.9329e-01, -7.7919e-01, -1.5320e-01,  1.0000e+00, -4.3274e-01,
          6.4268e-01,  7.7000e-01,  7.9197e-01, -5.4889e-02,  8.0927e-02,
          7.9722e-01,  2.1034e-01, -1.9189e-01, -4.4749e-01, -8.0585e-01,
         -3.5409e-01,  7.0995e-01,  1.2411e-01,  2.0604e-01,  8.3328e-01,
          7.4750e-01,  1.7900e-04,  7.6917e-02, -9.1725e-02,  9.9981e-01,
         -2.6801e-01, -8.3787e-02, -4.8642e-01,  1.1836e-01, -3.5603e-01,
         -5.8620e-01,  1.0000e+00,  2.2691e-01,  3.2801e-01, -9.9343e-01,
         -7.3298e-01, -9.5126e-01,  1.0000e+00,  8.4895e-01, -8.6216e-01,
          6.9319e-01,  5.5441e-01, -1.1380e-02,  8.6958e-01, -1.2449e-01,
         -2.8602e-01,  1.8517e-01,  9.2221e-02,  9.6773e-01, -4.5911e-01,
         -9.7611e-01, -6.6894e-01,  3.7154e-01, -9.7862e-01,  9.9949e-01,
         -5.5391e-01, -2.0926e-01, -3.9404e-01, -2.3863e-02,  6.3624e-01,
         -1.0563e-01, -9.8927e-01, -1.4047e-01,  1.2247e-01,  9.7469e-01,
          2.6847e-01, -6.0451e-01, -9.5354e-01,  4.5191e-01,  6.6822e-01,
         -7.2218e-01, -9.6438e-01,  9.7538e-01, -9.9165e-01,  5.6641e-01,
          1.0000e+00,  2.2837e-01, -2.8539e-01,  1.6956e-01, -4.6714e-01,
          2.5561e-01, -2.6744e-01,  7.4301e-01, -9.7890e-01, -2.7469e-01,
         -1.4162e-01,  2.7886e-01, -7.0853e-02, -5.8891e-02,  8.2879e-01,
          1.9968e-01, -5.4085e-01, -6.8158e-01,  3.7584e-02,  3.5805e-01,
          8.9092e-01, -1.7879e-01, -8.1491e-02,  5.0655e-02, -7.9140e-02,
         -9.5114e-01, -1.4923e-01, -3.5370e-01, -9.9994e-01,  7.4321e-01,
         -1.0000e+00,  2.1850e-01, -2.5182e-01, -2.2171e-01,  8.7817e-01,
          2.9648e-01,  3.4926e-01, -8.2534e-01, -3.8831e-01,  7.6622e-01,
          8.0938e-01, -2.1051e-01, -3.0882e-01, -7.6183e-01,  2.2523e-01,
         -1.4952e-02,  1.5150e-01, -2.1056e-01,  7.3482e-01, -1.5207e-01,
          1.0000e+00,  1.0631e-01, -7.7462e-01, -9.8438e-01,  1.6242e-01,
         -1.6337e-01,  1.0000e+00, -9.4196e-01, -9.7149e-01,  3.9827e-01,
         -7.2371e-01, -8.6582e-01,  3.0937e-01, -6.4325e-02, -8.1062e-01,
         -8.8436e-01,  9.8219e-01,  9.3543e-01, -5.6058e-01,  4.5004e-01,
         -3.2933e-01, -5.5851e-01, -6.9835e-02,  6.0196e-01,  9.9111e-01,
          4.1170e-01,  9.1721e-01,  5.9978e-01, -9.4103e-02,  9.7966e-01,
          1.5322e-01,  5.3662e-01,  4.2338e-02,  1.0000e+00,  2.8920e-01,
         -9.3933e-01,  2.4383e-01, -9.8948e-01, -1.5036e-01, -9.7242e-01,
          2.8053e-01,  1.1691e-01,  9.0178e-01, -2.1055e-01,  9.7547e-01,
         -5.0734e-01, -8.5119e-03, -5.2189e-01,  1.1963e-01,  4.0313e-01,
         -9.4529e-01, -9.8752e-01, -9.8975e-01,  4.5711e-01, -4.0753e-01,
          5.8175e-02,  1.1543e-01,  8.6051e-02,  3.6199e-01,  4.3131e-01,
         -1.0000e+00,  9.5818e-01,  4.0499e-01,  6.9443e-01,  9.7521e-01,
          6.7153e-01,  4.3386e-01,  2.2481e-01, -9.9118e-01, -9.9126e-01,
         -3.1248e-01, -1.4604e-01,  7.9951e-01,  6.1145e-01,  9.2726e-01,
          4.0171e-01, -3.9375e-01, -2.0938e-01, -3.2651e-02, -4.1723e-01,
         -9.9582e-01,  4.5682e-01, -9.4401e-02, -9.8150e-01,  9.6766e-01,
         -5.5518e-01, -8.0481e-02,  4.4743e-01, -6.0429e-01,  9.7261e-01,
          8.6633e-01,  3.7309e-01,  9.4917e-04,  4.6426e-01,  9.1590e-01,
          9.6965e-01,  9.8799e-01, -4.6592e-01,  8.7146e-01, -3.1116e-01,
          5.1496e-01,  6.7961e-01, -9.5609e-01,  1.3302e-03,  3.6581e-01,
         -2.3789e-01,  2.6341e-01, -1.2874e-01, -9.8464e-01,  4.8621e-01,
         -1.8921e-01,  6.1015e-01, -4.3986e-01,  2.1561e-01, -3.7115e-01,
         -1.5832e-02, -6.9704e-01, -7.3403e-01,  5.7310e-01,  5.0895e-01,
          9.4111e-01,  6.9365e-01,  6.9171e-02, -7.3277e-01, -1.1294e-01,
         -4.0168e-01, -9.2587e-01,  9.6638e-01,  2.2207e-02,  1.5029e-01,
          2.8954e-01, -8.5994e-02,  7.4631e-01, -1.5933e-01, -3.5710e-01,
         -1.6201e-01, -7.1149e-01,  9.0602e-01, -4.2873e-01, -4.6653e-01,
         -5.4765e-01,  7.4640e-01,  2.3966e-01,  9.9982e-01, -4.6795e-01,
         -6.4802e-01, -4.1201e-01, -3.4984e-01,  3.5475e-01, -5.4668e-01,
         -1.0000e+00,  3.6903e-01, -1.7324e-01,  4.3267e-01, -4.7206e-01,
          6.3586e-01, -5.2151e-01, -9.9077e-01, -1.6597e-01,  2.6735e-01,
          4.5069e-01, -4.3034e-01, -5.6321e-01,  5.7792e-01,  8.8123e-02,
          9.4964e-01,  9.2798e-01, -3.3326e-01,  5.1963e-01,  6.0865e-01,
         -4.4019e-01, -6.8129e-01,  9.3489e-01]], grad_fn=<TanhBackward>))
>>> len(output)
3
>>> output[0]
tensor([[[ -6.3390,  -6.3664,  -6.4600,  ...,  -5.5354,  -4.1787,  -5.8384],
         [ -6.3550,  -6.3077,  -6.4661,  ...,  -5.3516,  -4.1338,  -4.0742],
         [ -6.7090,  -6.6050,  -6.6682,  ...,  -5.9591,  -4.7142,  -3.8219],
         [ -7.7608,  -7.5956,  -7.6634,  ...,  -6.8113,  -5.7777,  -4.1638],
         [ -8.6462,  -8.5767,  -8.6366,  ...,  -7.9503,  -6.5382,  -5.0959],
         [-12.8752, -12.3775, -12.2770,  ..., -10.0880, -10.7659,  -9.0092]]],
       grad_fn=<AddBackward0>)
>>> output[0].shape
torch.Size([1, 6, 30522])
>>> output[1]
tensor([[[ 0.0929, -0.0264, -0.1224,  ..., -0.2106,  0.1739,  0.1725],
         [ 0.4074, -0.0593,  0.5523,  ..., -0.6791,  0.6556, -0.2946],
         [-0.2116, -0.6859, -0.4628,  ...,  0.1528,  0.5977, -0.9102],
         [ 0.3992, -1.3208, -0.0801,  ..., -0.3213,  0.2557, -0.5780],
         [-0.0757, -1.3394,  0.1816,  ...,  0.0746,  0.4032, -0.7080],
         [ 0.5989, -0.2841, -0.3490,  ...,  0.3042, -0.4368, -0.2097]]],
       grad_fn=<NativeLayerNormBackward>)
>>> output[1].shape
torch.Size([1, 6, 768])
>>> output[2]
tensor([[-9.3097e-01, -3.3807e-01, -6.2162e-01,  8.4082e-01,  4.4154e-01,
         -1.5889e-01,  9.3273e-01,  2.2240e-01, -4.3249e-01, -9.9998e-01,
         -2.7810e-01,  8.9449e-01,  9.8638e-01,  6.4763e-02,  9.6649e-01,
         -7.7835e-01, -4.4046e-01, -5.9515e-01,  2.7585e-01, -7.4638e-01,
          7.4700e-01,  9.9983e-01,  4.4468e-01,  2.8673e-01,  3.6586e-01,
          9.7642e-01, -8.4343e-01,  9.6599e-01,  9.7235e-01,  7.2667e-01,
         -7.5785e-01,  9.2892e-02, -9.9089e-01, -1.7004e-01, -6.8200e-01,
         -9.9283e-01,  2.6244e-01, -7.9871e-01,  2.3397e-02,  4.6413e-02,
         -9.3371e-01,  2.7699e-01,  9.9995e-01, -3.2671e-01,  2.1108e-01,
         -2.0636e-01, -1.0000e+00,  1.9622e-01, -9.3330e-01,  6.8736e-01,
          6.4731e-01,  5.3773e-01,  9.2759e-02,  4.1069e-01,  4.0360e-01,
          1.9002e-01, -1.7049e-01,  7.5259e-03, -2.0453e-01, -5.7574e-01,
         -5.3062e-01,  3.9367e-01, -7.0627e-01, -9.2865e-01,  6.8820e-01,
          3.2698e-01, -3.3506e-02, -1.2323e-01, -1.5304e-01, -1.8077e-01,
          9.3398e-01,  2.6375e-01,  3.7505e-01, -8.9548e-01,  1.1777e-01,
          2.2054e-01, -6.3351e-01,  1.0000e+00, -6.9228e-01, -9.8653e-01,
          6.9799e-01,  4.0303e-01,  5.2453e-01,  2.3217e-01, -1.2151e-01,
         -1.0000e+00,  5.6760e-01,  2.9295e-02, -9.9318e-01,  8.3171e-02,
          5.2939e-01, -2.3176e-01, -1.5694e-01,  4.9278e-01, -4.2614e-01,
         -3.8079e-01, -2.6060e-01, -6.9055e-01, -1.7180e-01, -1.9810e-01,
         -2.7986e-02, -7.2085e-02, -3.7635e-01, -3.7743e-01,  1.3508e-01,
         -4.3892e-01, -6.1321e-01,  1.7726e-01, -3.5434e-01,  6.4734e-01,
          4.0373e-01, -2.8194e-01,  4.5104e-01, -9.7876e-01,  6.1044e-01,
         -2.3526e-01, -9.9035e-01, -5.1350e-01, -9.9280e-01,  6.8329e-01,
         -2.1623e-01, -1.4641e-01,  9.8273e-01,  3.7345e-01,  4.8171e-01,
         -5.6467e-03, -7.3005e-01, -1.0000e+00, -7.2252e-01, -5.1978e-01,
          7.0765e-02, -1.5036e-01, -9.8355e-01, -9.7384e-01,  5.8453e-01,
          9.6710e-01,  1.4193e-01,  9.9981e-01, -2.1194e-01,  9.6675e-01,
          2.3627e-02, -4.1555e-01,  1.9872e-01, -4.0593e-01,  6.5180e-01,
          6.1598e-01, -6.8750e-01,  7.9808e-02, -2.0437e-01,  3.4504e-01,
         -6.7176e-01, -1.3692e-01, -2.7750e-01, -9.6740e-01, -3.6698e-01,
          9.6934e-01, -2.5050e-01, -6.9297e-01,  4.8327e-01, -1.4613e-01,
         -5.1224e-01,  8.8387e-01,  6.9173e-01,  3.8395e-01, -1.7536e-01,
          3.8873e-01, -4.3011e-03,  6.1876e-01, -8.9292e-01,  3.4243e-02,
          4.5193e-01, -2.4782e-01, -4.7402e-01, -9.8375e-01, -3.1763e-01,
          5.9109e-01,  9.9284e-01,  7.9634e-01,  2.4601e-01,  6.1729e-01,
         -1.8376e-01,  6.8750e-01, -9.7083e-01,  9.8624e-01, -2.0573e-01,
          2.0418e-01,  4.1400e-01,  1.9102e-01, -9.1718e-01, -3.5273e-01,
          8.9628e-01, -5.6812e-01, -8.9552e-01, -3.5567e-02, -4.9052e-01,
         -4.3559e-01, -6.2323e-01,  5.6863e-01, -2.6201e-01, -3.1324e-01,
         -1.2852e-02,  9.4585e-01,  9.8664e-01,  8.3363e-01, -2.4392e-01,
          7.3786e-01, -9.4466e-01, -5.2720e-01, -1.6349e-02,  2.4207e-01,
          3.6905e-02,  9.9638e-01, -5.8095e-01, -7.2046e-02, -9.4418e-01,
         -9.8921e-01, -1.0289e-01, -9.3301e-01, -5.3531e-02, -6.8719e-01,
          5.3295e-01,  1.6390e-01,  2.3460e-01,  4.3260e-01, -9.9501e-01,
         -7.7318e-01,  2.6342e-01, -3.6949e-01,  4.0245e-01, -1.6657e-01,
          4.5766e-01,  7.4537e-01, -5.8549e-01,  8.4632e-01,  9.3526e-01,
         -6.4963e-01, -7.8264e-01,  8.5868e-01, -2.9683e-01,  9.0246e-01,
         -6.5124e-01,  9.8896e-01,  8.6732e-01,  8.7014e-01, -9.5627e-01,
         -4.1195e-01, -9.1043e-01, -4.5438e-01,  5.7729e-02, -3.6862e-01,
          5.7032e-01,  5.5757e-01,  3.0482e-01,  7.0850e-01, -6.6279e-01,
          9.9909e-01, -5.0139e-01, -9.7001e-01, -2.2370e-01, -1.8440e-02,
         -9.9107e-01,  7.2208e-01,  2.4379e-01,  6.9083e-02, -3.2313e-01,
         -7.3217e-01, -9.7295e-01,  9.2268e-01,  5.0675e-02,  9.9215e-01,
         -8.0247e-02, -9.5682e-01, -4.1637e-01, -9.4549e-01, -2.9790e-01,
         -1.5625e-01,  2.4707e-01, -1.8468e-01, -9.7276e-01,  4.7428e-01,
          5.6760e-01,  5.5919e-01, -1.9418e-01,  9.9932e-01,  1.0000e+00,
          9.7844e-01,  9.3669e-01,  9.5284e-01, -9.9929e-01, -4.9083e-01,
          9.9999e-01, -9.7835e-01, -1.0000e+00, -9.5292e-01, -6.5736e-01,
          4.1425e-01, -1.0000e+00, -2.7896e-03,  7.0756e-02, -9.4186e-01,
          2.7960e-01,  9.8389e-01,  9.9658e-01, -1.0000e+00,  8.8289e-01,
          9.6828e-01, -6.2958e-01,  9.3367e-01, -3.7519e-01,  9.8027e-01,
          4.2505e-01,  3.0766e-01, -3.2042e-01,  2.7469e-01, -7.8253e-01,
         -8.8309e-01, -1.5604e-01, -3.6222e-01,  9.9091e-01,  3.0116e-02,
         -7.8697e-01, -9.4496e-01,  2.2050e-01, -8.4521e-02, -4.8378e-01,
         -9.7952e-01, -1.3446e-01,  4.2209e-01,  7.8760e-01,  6.4992e-02,
          2.0492e-01, -7.8143e-01,  2.2120e-01, -5.0228e-01,  3.7149e-01,
          6.5244e-01, -9.4897e-01, -6.0978e-01, -4.8976e-03, -4.6856e-01,
         -2.8122e-01, -9.6984e-01,  9.8036e-01, -3.5220e-01,  7.4903e-01,
          1.0000e+00, -1.0373e-01, -9.4037e-01,  6.2856e-01,  1.5745e-01,
         -1.1596e-01,  1.0000e+00,  7.2891e-01, -9.8543e-01, -5.3814e-01,
          4.0543e-01, -4.9501e-01, -4.8527e-01,  9.9950e-01, -1.4058e-01,
         -2.3799e-01, -1.5841e-01,  9.8467e-01, -9.9180e-01,  9.7240e-01,
         -9.5292e-01, -9.8022e-01,  9.8012e-01,  9.5211e-01, -6.6387e-01,
         -7.2622e-01,  1.1509e-01, -3.1365e-01,  1.8487e-01, -9.7602e-01,
          7.9482e-01,  5.2428e-01, -1.3540e-01,  9.1377e-01, -9.0275e-01,
         -5.2769e-01,  2.8301e-01, -4.9215e-01,  3.4866e-02,  8.3573e-01,
          5.0270e-01, -2.4031e-01, -6.1194e-02, -2.5558e-01, -1.3530e-01,
         -9.8688e-01,  2.9877e-01,  1.0000e+00, -5.3199e-02,  4.4522e-01,
         -2.4564e-01,  3.8897e-02, -3.7170e-01,  3.6843e-01,  5.1087e-01,
         -1.9742e-01, -8.8481e-01,  4.5420e-01, -9.8222e-01, -9.8894e-01,
          8.5417e-01,  1.4674e-01, -3.3154e-01,  9.9999e-01,  4.4333e-01,
          7.1728e-02,  1.6790e-01,  9.6064e-01,  2.3267e-02,  7.2436e-01,
          4.9905e-01,  9.8528e-01, -2.0286e-01,  5.2711e-01,  9.0711e-01,
         -5.6147e-01, -3.4452e-01, -6.1113e-01, -8.1268e-02, -9.2887e-01,
          1.0119e-01, -9.7066e-01,  9.7404e-01,  8.2025e-01,  2.9760e-01,
          1.9059e-01,  3.6089e-01,  1.0000e+00, -2.7256e-01,  6.5052e-01,
         -6.0092e-01,  9.0897e-01, -9.9819e-01, -9.1409e-01, -3.7810e-01,
          1.2677e-02, -3.9492e-01, -3.0028e-01,  3.4323e-01, -9.7925e-01,
          4.4501e-01,  3.7582e-01, -9.9622e-01, -9.9495e-01,  1.6366e-01,
          9.2522e-01, -1.3063e-02, -9.5314e-01, -7.5003e-01, -6.5409e-01,
          4.1526e-01, -7.6235e-02, -9.6046e-01,  3.2395e-01, -2.7184e-01,
          4.7535e-01, -1.1767e-01,  5.6867e-01,  4.6844e-01,  8.3125e-01,
         -2.1505e-01, -2.6495e-01, -4.4479e-02, -8.5166e-01,  8.8927e-01,
         -8.9329e-01, -7.7919e-01, -1.5320e-01,  1.0000e+00, -4.3274e-01,
          6.4268e-01,  7.7000e-01,  7.9197e-01, -5.4889e-02,  8.0927e-02,
          7.9722e-01,  2.1034e-01, -1.9189e-01, -4.4749e-01, -8.0585e-01,
         -3.5409e-01,  7.0995e-01,  1.2411e-01,  2.0604e-01,  8.3328e-01,
          7.4750e-01,  1.7900e-04,  7.6917e-02, -9.1725e-02,  9.9981e-01,
         -2.6801e-01, -8.3787e-02, -4.8642e-01,  1.1836e-01, -3.5603e-01,
         -5.8620e-01,  1.0000e+00,  2.2691e-01,  3.2801e-01, -9.9343e-01,
         -7.3298e-01, -9.5126e-01,  1.0000e+00,  8.4895e-01, -8.6216e-01,
          6.9319e-01,  5.5441e-01, -1.1380e-02,  8.6958e-01, -1.2449e-01,
         -2.8602e-01,  1.8517e-01,  9.2221e-02,  9.6773e-01, -4.5911e-01,
         -9.7611e-01, -6.6894e-01,  3.7154e-01, -9.7862e-01,  9.9949e-01,
         -5.5391e-01, -2.0926e-01, -3.9404e-01, -2.3863e-02,  6.3624e-01,
         -1.0563e-01, -9.8927e-01, -1.4047e-01,  1.2247e-01,  9.7469e-01,
          2.6847e-01, -6.0451e-01, -9.5354e-01,  4.5191e-01,  6.6822e-01,
         -7.2218e-01, -9.6438e-01,  9.7538e-01, -9.9165e-01,  5.6641e-01,
          1.0000e+00,  2.2837e-01, -2.8539e-01,  1.6956e-01, -4.6714e-01,
          2.5561e-01, -2.6744e-01,  7.4301e-01, -9.7890e-01, -2.7469e-01,
         -1.4162e-01,  2.7886e-01, -7.0853e-02, -5.8891e-02,  8.2879e-01,
          1.9968e-01, -5.4085e-01, -6.8158e-01,  3.7584e-02,  3.5805e-01,
          8.9092e-01, -1.7879e-01, -8.1491e-02,  5.0655e-02, -7.9140e-02,
         -9.5114e-01, -1.4923e-01, -3.5370e-01, -9.9994e-01,  7.4321e-01,
         -1.0000e+00,  2.1850e-01, -2.5182e-01, -2.2171e-01,  8.7817e-01,
          2.9648e-01,  3.4926e-01, -8.2534e-01, -3.8831e-01,  7.6622e-01,
          8.0938e-01, -2.1051e-01, -3.0882e-01, -7.6183e-01,  2.2523e-01,
         -1.4952e-02,  1.5150e-01, -2.1056e-01,  7.3482e-01, -1.5207e-01,
          1.0000e+00,  1.0631e-01, -7.7462e-01, -9.8438e-01,  1.6242e-01,
         -1.6337e-01,  1.0000e+00, -9.4196e-01, -9.7149e-01,  3.9827e-01,
         -7.2371e-01, -8.6582e-01,  3.0937e-01, -6.4325e-02, -8.1062e-01,
         -8.8436e-01,  9.8219e-01,  9.3543e-01, -5.6058e-01,  4.5004e-01,
         -3.2933e-01, -5.5851e-01, -6.9835e-02,  6.0196e-01,  9.9111e-01,
          4.1170e-01,  9.1721e-01,  5.9978e-01, -9.4103e-02,  9.7966e-01,
          1.5322e-01,  5.3662e-01,  4.2338e-02,  1.0000e+00,  2.8920e-01,
         -9.3933e-01,  2.4383e-01, -9.8948e-01, -1.5036e-01, -9.7242e-01,
          2.8053e-01,  1.1691e-01,  9.0178e-01, -2.1055e-01,  9.7547e-01,
         -5.0734e-01, -8.5119e-03, -5.2189e-01,  1.1963e-01,  4.0313e-01,
         -9.4529e-01, -9.8752e-01, -9.8975e-01,  4.5711e-01, -4.0753e-01,
          5.8175e-02,  1.1543e-01,  8.6051e-02,  3.6199e-01,  4.3131e-01,
         -1.0000e+00,  9.5818e-01,  4.0499e-01,  6.9443e-01,  9.7521e-01,
          6.7153e-01,  4.3386e-01,  2.2481e-01, -9.9118e-01, -9.9126e-01,
         -3.1248e-01, -1.4604e-01,  7.9951e-01,  6.1145e-01,  9.2726e-01,
          4.0171e-01, -3.9375e-01, -2.0938e-01, -3.2651e-02, -4.1723e-01,
         -9.9582e-01,  4.5682e-01, -9.4401e-02, -9.8150e-01,  9.6766e-01,
         -5.5518e-01, -8.0481e-02,  4.4743e-01, -6.0429e-01,  9.7261e-01,
          8.6633e-01,  3.7309e-01,  9.4917e-04,  4.6426e-01,  9.1590e-01,
          9.6965e-01,  9.8799e-01, -4.6592e-01,  8.7146e-01, -3.1116e-01,
          5.1496e-01,  6.7961e-01, -9.5609e-01,  1.3302e-03,  3.6581e-01,
         -2.3789e-01,  2.6341e-01, -1.2874e-01, -9.8464e-01,  4.8621e-01,
         -1.8921e-01,  6.1015e-01, -4.3986e-01,  2.1561e-01, -3.7115e-01,
         -1.5832e-02, -6.9704e-01, -7.3403e-01,  5.7310e-01,  5.0895e-01,
          9.4111e-01,  6.9365e-01,  6.9171e-02, -7.3277e-01, -1.1294e-01,
         -4.0168e-01, -9.2587e-01,  9.6638e-01,  2.2207e-02,  1.5029e-01,
          2.8954e-01, -8.5994e-02,  7.4631e-01, -1.5933e-01, -3.5710e-01,
         -1.6201e-01, -7.1149e-01,  9.0602e-01, -4.2873e-01, -4.6653e-01,
         -5.4765e-01,  7.4640e-01,  2.3966e-01,  9.9982e-01, -4.6795e-01,
         -6.4802e-01, -4.1201e-01, -3.4984e-01,  3.5475e-01, -5.4668e-01,
         -1.0000e+00,  3.6903e-01, -1.7324e-01,  4.3267e-01, -4.7206e-01,
          6.3586e-01, -5.2151e-01, -9.9077e-01, -1.6597e-01,  2.6735e-01,
          4.5069e-01, -4.3034e-01, -5.6321e-01,  5.7792e-01,  8.8123e-02,
          9.4964e-01,  9.2798e-01, -3.3326e-01,  5.1963e-01,  6.0865e-01,
         -4.4019e-01, -6.8129e-01,  9.3489e-01]], grad_fn=<TanhBackward>)
>>> output[2].shape
torch.Size([1, 768])
>>>


I got the same issue

Did you try to follow my suggestion reported above? In my environment, it works as expected.
My environment details:

  • OS: Ubuntu 16.04
  • Python: 3.6.9
  • Transformers: 2.2.2 (installed with pip install transformers)
  • PyTorch: 1.3.1
  • TensorFlow: 2.0

If not, can you post your environment and a list of steps to reproduce the bug?

That said, I've tested your code and worked out how to use PreTrainedEncoderDecoder correctly, without bugs. You can see my code below.

@TheEdoardo93 that doesn't make sense; you're giving your encoder's input as the decoder's input.

Never mind, I know what the issue is, so I'm closing it.

Sorry, it was my mistake. Can you share with us what the problem was and how you solved it?

@anandhperumal Could you please share how you solved the issue? Did you pass a <BOS> token to the decoder input?
I'd appreciate it.

@anandhperumal can you let us know how you fixed the issue?

You can pass a start token to the decoder, just like in a seq2seq architecture.
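
For example, here is a hedged sketch of that idea (an illustration only, not the library's generation API): seed decoder_input_ids with a single start token (BERT has no dedicated <BOS>, so its [CLS] id is used as a stand-in here) and extend it one greedily chosen token at a time. Note that an encoder-decoder built from two pretrained BERT checkpoints will not produce meaningful text without fine-tuning.

import torch
from transformers import PreTrainedEncoderDecoder, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = PreTrainedEncoderDecoder.from_pretrained('bert-base-uncased', 'bert-base-uncased')

encoder_input_ids = torch.tensor(tokenizer.encode("Hi How are you")).unsqueeze(0)

# Seed the decoder with a single start token ([CLS] as a stand-in for <BOS>).
decoder_input_ids = torch.tensor([[tokenizer.cls_token_id]])

with torch.no_grad():
    for _ in range(20):  # decode at most 20 tokens
        outputs = model(encoder_input_ids, decoder_input_ids)
        # outputs[0]: decoder prediction scores, shape (1, current_length, vocab_size)
        next_token = outputs[0][:, -1, :].argmax(dim=-1, keepdim=True)
        decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)
        if next_token.item() == tokenizer.sep_token_id:  # treat [SEP] as end-of-sequence
            break

print(tokenizer.decode(decoder_input_ids[0].tolist()))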
