Hi,
I'm trying to load a fine-tuned model for question answering, which I trained with run_squad.py:
import torch
from pytorch_pretrained_bert import modeling

# Config matching bert-base-uncased (vocab size 30522)
config = modeling.BertConfig(
    attention_probs_dropout_prob=0.1,
    hidden_dropout_prob=0.1,
    hidden_size=768,
    initializer_range=0.02,
    intermediate_size=3072,
    max_position_embeddings=512,
    num_attention_heads=12,
    num_hidden_layers=12,
    vocab_size_or_config_json_file=30522)

model = modeling.BertForQuestionAnswering(config)
model_state_dict = "/home/ubuntu/bert_squad/bert_fine_121918/pytorch_model.bin"
model.bert.load_state_dict(torch.load(model_state_dict))
but I receive an error on the last line:
Error(s) in loading state_dict for BertModel:
Missing key(s) in state_dict: "embeddings.word_embeddings.weight", "embeddings.position_embeddings.weight", "embeddings.token_type_embeddings.weight", "embeddings.LayerNorm.weight", "embeddings.LayerNorm.bias", "encoder.layer.0.attention.self.query.weight",....
Unexpected key(s) in state_dict: "bert.embeddings.word_embeddings.weight", "bert.embeddings.position_embeddings.weight", "bert.embeddings.token_type_embeddings.weight", "bert.embeddings.LayerNorm.weight", "bert.embeddings.LayerNorm.bias", "bert.encoder.layer.0.attention.self.query.weight",....
It looks like the model definition is not in the expected format. Could you point me toward what went wrong?
Judging from the error message, I would say that it is caused by the following line: https://github.com/huggingface/pytorch-pretrained-BERT/blob/7fb94ab934b2ad1041613fc93c61d13105faf98a/pytorch_pretrained_bert/modeling.py#L541
Apparently, the proper way to save a model is the following:
https://github.com/huggingface/pytorch-pretrained-BERT/blob/7fb94ab934b2ad1041613fc93c61d13105faf98a/examples/run_classifier.py#L554-L557
Is this what you are doing?
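For reference, the linked save snippet amounts to roughly this pattern (a paraphrased sketch, not the verbatim source; model and args.output_dir come from the example script):
import os
import torch

# Unwrap DataParallel (if used) so the saved keys match the bare model
model_to_save = model.module if hasattr(model, 'module') else model
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
torch.save(model_to_save.state_dict(), output_model_file)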
Hi @rodgzilla, I see that the model is saved the same way in run_squad.py:
https://github.com/huggingface/pytorch-pretrained-BERT/blob/7fb94ab934b2ad1041613fc93c61d13105faf98a/examples/run_squad.py#L918-L921
So the problem must be elsewhere.
I ran into the same problem, using the pytorch_model.bin generated by run_classifier.py:
!python pytorch-pretrained-BERT/examples/run_classifier.py \
--task_name=MRPC \
--do_train \
--do_eval \
--data_dir=./ \
--bert_model=bert-base-chinese \
--max_seq_length=64 \
--train_batch_size=32 \
--learning_rate=2e-5 \
--num_train_epochs=3.0 \
--output_dir=./models/
Then I try to load the fine-tuned model:
import torch
from pytorch_pretrained_bert import modeling
from pytorch_pretrained_bert import BertForSequenceClassification

# Load pre-trained model (weights); vocab size 21128 matches bert-base-chinese
config = modeling.BertConfig(
    vocab_size_or_config_json_file=21128,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02)
model = BertForSequenceClassification(config)
model_state_dict = "models/pytorch_model.bin"
model.bert.load_state_dict(torch.load(model_state_dict))
RuntimeError Traceback (most recent call last)
<ipython-input-22-cdc19dc2541c> in <module>()
20 # issues: https://github.com/huggingface/pytorch-pretrained-BERT/issues/138
21 model_state_dict = "models/pytorch_model.bin"
---> 22 model.bert.load_state_dict(torch.load(model_state_dict))
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
767 if len(error_msgs) > 0:
768 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 769 self.__class__.__name__, "\n\t".join(error_msgs)))
770
771 def _named_members(self, get_members_fn, prefix='', recurse=True):
RuntimeError: Error(s) in loading state_dict for BertModel:
Missing key(s) in state_dict: "embeddings.word_embeddings.weight", "embeddings.position_embeddings.weight", "embeddings.token_type_embeddings.weight", "embeddings.LayerNorm.weight", "embeddings.LayerNorm.bias", "encoder.layer.0.attention.self.query.weight", "encoder.layer.0.attention.self.query.bias", "encoder.layer.0.attention.self.key.weight", "encoder.layer.0.attention.self.key.bias", "encoder.layer.0.attention.self.value.weight", "encoder.layer.0.attention.self.value.bias", "encoder.layer.0.attention.output.dense.weight", "encoder.layer.0.attention.output.dense.bias", "encoder.layer.0.attention.output.LayerNorm.weight", "encoder.layer.0.attention.output.LayerNorm.bias", "encoder.layer.0.intermediate.dense.weight", "encoder.layer.0.intermediate.dense.bias", "encoder.layer.0.output.dense.weight", "encoder.layer.0.output.dense.bias", "encoder.layer.0.output.LayerNorm.weight", "encoder.layer.0.output.LayerNorm.bias", "encoder.layer.1.attention.self.query.weight", "encoder.layer.1.attention.self.query.bias", "encoder.layer.1.attention.self.key.weight", "encoder.layer.1.attention.self.key.bias", "encoder.layer.1.attention.self.value.weight", "encoder.layer.1.attention.self.value.bias", "encoder.layer.1.attention.output.dense.weight", "encoder.layer.1.attention.output.dense.bias", "encoder.layer.1.attention.output.LayerNorm.weight", "encoder.layer.1.attention.output.LayerNorm.bias", "encoder.layer.1.intermediate.dense.weight", "encoder.layer.1.intermediate.dense.bias", "enco...
Unexpected key(s) in state_dict: "bert.embeddings.word_embeddings.weight", "bert.embeddings.position_embeddings.weight", "bert.embeddings.token_type_embeddings.weight", "bert.embeddings.LayerNorm.weight", "bert.embeddings.LayerNorm.bias", "bert.encoder.layer.0.attention.self.query.weight", "bert.encoder.layer.0.attention.self.query.bias", "bert.encoder.layer.0.attention.self.key.weight", "bert.encoder.layer.0.attention.self.key.bias", "bert.encoder.layer.0.attention.self.value.weight", "bert.encoder.layer.0.attention.self.value.bias", "bert.encoder.layer.0.attention.output.dense.weight", "bert.encoder.layer.0.attention.output.dense.bias", "bert.encoder.layer.0.attention.output.LayerNorm.weight", "bert.encoder.layer.0.attention.output.LayerNorm.bias", "bert.encoder.layer.0.intermediate.dense.weight", "bert.encoder.layer.0.intermediate.dense.bias", "bert.encoder.layer.0.output.dense.weight", "bert.encoder.layer.0.output.dense.bias", "bert.encoder.layer.0.output.LayerNorm.weight", "bert.encoder.layer.0.output.LayerNorm.bias", "bert.encoder.layer.1.attention.self.query.weight", "bert.encoder.layer.1.attention.self.query.bias", "bert.encoder.layer.1.attention.self.key.weight", "bert.encoder.layer.1.attention.self.key.bias", "bert.encoder.layer.1.attention.self.value.weight", "bert.encoder.layer.1.attention.self.value.bias", "bert.encoder.layer.1.attention.output.dense.weight", "bert.encoder.layer.1.attention.output.dense.bias", "bert.encoder.layer.1.attention.output.LayerNorm....
How can I load a fine-tuned model?
Hi, here the problem is not with the saving of the model but with the loading: the checkpoint was saved from the full fine-tuned model, so every encoder weight is stored under a "bert." prefix, while model.bert expects keys without that prefix.
You should just use
model.load_state_dict(torch.load(model_state_dict))
and not
model.bert.load_state_dict(torch.load(model_state_dict))
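If you really do need to load only the encoder weights into model.bert, a minimal sketch (assuming the checkpoint was saved from the full fine-tuned model, as the error message indicates) is to strip the "bert." prefix from the keys first:
import torch

state_dict = torch.load("models/pytorch_model.bin")
# Encoder weights were saved under a "bert." prefix; drop it and skip the
# task-specific head so the keys match what model.bert expects
encoder_state = {k[len("bert."):]: v for k, v in state_dict.items() if k.startswith("bert.")}
model.bert.load_state_dict(encoder_state)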
Alternatively, here is an example of how to save and then load a model using from_pretrained: https://github.com/huggingface/pytorch-pretrained-BERT/blob/2e4db64cab198dc241e18221ef088908f2587c61/examples/run_squad.py#L916-L924
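For completeness, the loading half of that linked example looks roughly like this (a sketch; "bert-base-uncased" stands in for whatever base model you fine-tuned from, and the .bin path is illustrative):
import torch
from pytorch_pretrained_bert import BertForQuestionAnswering

# Rebuild the architecture from the base checkpoint, then overwrite its
# weights with the fine-tuned ones via the state_dict argument
model_state_dict = torch.load("models/pytorch_model.bin")
model = BertForQuestionAnswering.from_pretrained("bert-base-uncased", state_dict=model_state_dict)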