As said in #2117 by @rlouf (an author of Transformers), at the moment you can use PreTrainedEncoderDecoder only with BERT as both the encoder and the decoder.
In more detail, he said: "_Indeed, as I specified in the article, PreTrainedEncoderDecoder only works with BERT as an encoder and BERT as a decoder. GPT2 shouldn't take too much work to adapt, but we haven't had the time to do it yet. Try PreTrainedEncoderDecoder.from_pretrained('bert-base-uncased', 'bert-base-uncased') should work. Let me know if it doesn't._".
Bug
Model I am using (Bert, XLNet....):
Language I am using the model on (English, Chinese....):
The problem arises when using:
- [x] the official example scripts: (give details)
- [ ] my own modified scripts: (give details)
The task I am working on is:
- [ ] an official GLUE/SQuAD task: (give the name)
- [x] my own task or dataset: (give details)
To Reproduce
Steps to reproduce the behavior:
- Error occurs while doing inference:

from transformers import PreTrainedEncoderDecoder, BertTokenizer
model = PreTrainedEncoderDecoder.from_pretrained('bert-base-uncased', 'gpt2')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
encoder_input_ids = tokenizer.encode("Hi How are you")
import torch
output = model(torch.tensor(encoder_input_ids).unsqueeze(0))

and the error is:

TypeError: forward() missing 1 required positional argument: 'decoder_input_ids'

During inference, why is a decoder input expected?
Let me know if I'm missing anything.
Environment
OS: Ubuntu
Python version: 3.6
PyTorch version: 1.3.0
PyTorch Transformers version (or branch): 2.2.0
Using GPU? Yes
Distributed or parallel setup? No
Any other relevant information:

Additional context
@TheEdoardo93 it doesn't matter whether it is GPT2 or BERT; both give the same error:
I'm trying to play with GPT2, that's why I pasted my own code.
Using BERT as an Encoder and Decoder
>>> model = PreTrainedEncoderDecoder.from_pretrained('bert-base-uncased', 'bert-base-uncased')
>>> output = model(torch.tensor(encoder_input_ids).unsqueeze(0))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/guest_1/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'decoder_input_ids'
First of all, the authors of Transformers are still working on the implementation of the PreTrainedEncoderDecoder object, so it's not a definitive implementation; for example, some methods are not implemented yet. That said, I've tested your code and figured out how to use PreTrainedEncoderDecoder correctly. You can see my code below.
In brief, your problem occurs because you have not passed _all_ the arguments the forward method requires. By looking at the source code here, you can see that this method accepts two parameters: encoder_input_ids and decoder_input_ids. In your code you passed _only one_, so the Python interpreter binds your encoder_input_ids to the forward method's encoder_input_ids, but no value is supplied for decoder_input_ids, and that is what raises the error.
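In short, a minimal sketch with the same 'bert-base-uncased' checkpoints (the full interactive session is reproduced below):

import torch
from transformers import PreTrainedEncoderDecoder, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = PreTrainedEncoderDecoder.from_pretrained('bert-base-uncased', 'bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hi How are you")).unsqueeze(0)

# output = model(input_ids)           # TypeError: missing 'decoder_input_ids'
output = model(input_ids, input_ids)  # both encoder_input_ids and decoder_input_ids supplied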
Python 3.6.9 |Anaconda, Inc.| (default, Jul 30 2019, 19:07:31)
[GCC 7.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import transformers
>>> from transformers import PreTrainedEncoderDecoder
>>> from transformers import BertTokenizer
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
>>> model = PreTrainedEncoderDecoder.from_pretrained('bert-base-uncased', 'bert-base-uncased')
>>> text='Hi How are you'
>>> import torch
>>> input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)
>>> input_ids
tensor([[ 101, 7632, 2129, 2024, 2017, 102]])
>>> output = model(input_ids) # YOUR PROBLEM IS HERE
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/<user>/anaconda3/envs/huggingface/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'decoder_input_ids'
>>> output = model(input_ids, input_ids) # SOLUTION TO YOUR PROBLEM
>>> len(output)
3
>>> output[0]
tensor([[[ -6.3390, -6.3664, -6.4600, ..., -5.5354, -4.1787, -5.8384],
[ -6.3550, -6.3077, -6.4661, ..., -5.3516, -4.1338, -4.0742],
[ -6.7090, -6.6050, -6.6682, ..., -5.9591, -4.7142, -3.8219],
[ -7.7608, -7.5956, -7.6634, ..., -6.8113, -5.7777, -4.1638],
[ -8.6462, -8.5767, -8.6366, ..., -7.9503, -6.5382, -5.0959],
[-12.8752, -12.3775, -12.2770, ..., -10.0880, -10.7659, -9.0092]]],
grad_fn=<AddBackward0>)
>>> output[0].shape
torch.Size([1, 6, 30522])
>>> output[1]
tensor([[[ 0.0929, -0.0264, -0.1224, ..., -0.2106, 0.1739, 0.1725],
[ 0.4074, -0.0593, 0.5523, ..., -0.6791, 0.6556, -0.2946],
[-0.2116, -0.6859, -0.4628, ..., 0.1528, 0.5977, -0.9102],
[ 0.3992, -1.3208, -0.0801, ..., -0.3213, 0.2557, -0.5780],
[-0.0757, -1.3394, 0.1816, ..., 0.0746, 0.4032, -0.7080],
[ 0.5989, -0.2841, -0.3490, ..., 0.3042, -0.4368, -0.2097]]],
grad_fn=<NativeLayerNormBackward>)
>>> output[1].shape
torch.Size([1, 6, 768])
>>> output[2]
tensor([[-9.3097e-01, -3.3807e-01, -6.2162e-01,  8.4082e-01,  4.4154e-01,
         -1.5889e-01,  9.3273e-01,  2.2240e-01, -4.3249e-01, -9.9998e-01,
         ...,
         -4.4019e-01, -6.8129e-01,  9.3489e-01]], grad_fn=<TanhBackward>)
>>> output[2].shape
torch.Size([1, 768])
>>>
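For reference, my reading of the three outputs above (treat this as an assumption about the 2.2.0 return order rather than documented behaviour): output[0] holds token-level scores over BERT's 30,522-word vocabulary for each of the 6 positions, output[1] holds 768-dimensional per-token hidden states, and output[2] is a pooled 768-dimensional sentence representation. Without any fine-tuning the scores are not meaningful, but a quick way to map them back to tokens:

predicted_ids = output[0].argmax(dim=-1)  # most likely vocabulary id at each position, shape [1, 6]
print(tokenizer.convert_ids_to_tokens(predicted_ids[0].tolist()))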
I got the same issue
Did you try to follow my suggestion reported above? In my environment, it works as expected.
My environment details: Transformers installed via pip (pip install transformers).
If not, can you post your environment and a list of steps to reproduce the bug?
That said, I've tested your code and figured out how to use PreTrainedEncoderDecoder correctly. You can see my code below.
@TheEdoardo93 that doesn't make sense: you're giving your encoder's input as the decoder's input.
Never mind, I know what the issue is, so I'm closing it.
Sorry, it was my mistake. Can you share with us what the problem was and how to solve it?
@anandhperumal Could you please share how you solved the issue? Did you pass a <BOS> token to the decoder input?
Appreciate it.
@anandhperumal can you let us know how you fixed the issue?
You can pass a start token to the decoder, just like in a standard seq2seq architecture.
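A rough greedy-decoding sketch of that idea, using the same transformers 2.2.x API shown above. The choice of [CLS] as the decoder start token, the maximum length, and the stop condition are my own assumptions, and an untuned BERT decoder won't produce meaningful text without fine-tuning:

import torch
from transformers import PreTrainedEncoderDecoder, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = PreTrainedEncoderDecoder.from_pretrained('bert-base-uncased', 'bert-base-uncased')
model.eval()

encoder_input_ids = torch.tensor(tokenizer.encode("Hi How are you")).unsqueeze(0)

# Seed the decoder with a single start token ([CLS] here) and extend it greedily.
decoder_input_ids = torch.tensor([[tokenizer.cls_token_id]])
with torch.no_grad():
    for _ in range(10):  # arbitrary maximum output length
        outputs = model(encoder_input_ids, decoder_input_ids)
        next_token_logits = outputs[0][:, -1, :]            # scores for the last decoder position
        next_token = next_token_logits.argmax(dim=-1, keepdim=True)
        decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)
        if next_token.item() == tokenizer.sep_token_id:      # stop once the decoder emits [SEP]
            break

print(tokenizer.decode(decoder_input_ids[0].tolist()))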