Transformers: How to save a model as a BertModel

Created on 7 Dec 2019  ·  3 Comments  ·  Source: huggingface/transformers

❓ Questions & Help

I first fine-tuned a bert-base-uncased model on the SST-2 dataset with run_glue.py. Then I wanted to use the resulting pytorch_model.bin for further fine-tuning on the MNLI dataset. But if I load this pytorch_model.bin directly, an error occurs:

RuntimeError: Error(s) in loading state_dict for BertForSequenceClassification:
size mismatch for classifier.weight: copying a param with shape torch.Size([2, 768]) from checkpoint, the shape in current model is torch.Size([3, 768]).
size mismatch for classifier.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([3]).

This error occurs because SST-2 has two classes but MNLI has three. Issue #1108 provides a solution: save the BertModel without the classification head. But I wonder whether that is still feasible when the model class was chosen as BertForSequenceClassification from the beginning. How do I change the model class at the saving step?
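The headless-save approach from #1108 can be sketched as follows. This is only a sketch: the tiny random config is a stand-in so the snippet runs without the real SST-2 checkpoint (with a real run you would load it via BertForSequenceClassification.from_pretrained instead). The key point is that the encoder lives on the model's `bert` attribute, so it can be saved on its own, without the head:

```python
import tempfile

from transformers import BertConfig, BertForSequenceClassification, BertModel

# Tiny random model as a stand-in for the SST-2 fine-tuned checkpoint.
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1,
                    num_attention_heads=2, intermediate_size=64, num_labels=2)
model = BertForSequenceClassification(config)

out_dir = tempfile.mkdtemp()
# The encoder (without the classification head) lives on `model.bert`;
# saving it alone yields a headless checkpoint any task head can reuse.
model.bert.save_pretrained(out_dir)

# It loads back as a plain BertModel, with no classifier.* keys.
encoder = BertModel.from_pretrained(out_dir)
```

The same pattern works for the other heads too (e.g. `model.bert` of a BertForQuestionAnswering), since every task model wraps the same base encoder.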

wontfix

Most helpful comment

Hello! If you try to load your pytorch_model.bin directly in BertForSequenceClassification, you'll indeed get an error as the model won't know that it is supposed to have three classes. That's what the configuration is for!

I guess you're doing something similar to this:

import torch
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("bert-base-cased")
model.load_state_dict(torch.load("SAVED_SST_MODEL_DIR/pytorch_model.bin"))
# Crashes here

Instead, if you saved the model with the save_pretrained method, the directory should already contain a config.json specifying the shape of the model, so you can simply load it with:

from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("SAVED_SST_MODEL_DIR")
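For completeness, here is a sketch of the save side that produces such a directory, again with a tiny random config as a stand-in for the fine-tuned model:

```python
import os
import tempfile

from transformers import BertConfig, BertForSequenceClassification

# Tiny random stand-in for a fine-tuned model.
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1,
                    num_attention_heads=2, intermediate_size=64, num_labels=2)
model = BertForSequenceClassification(config)

save_dir = tempfile.mkdtemp()
# Writes the weights file plus a config.json that records num_labels,
# which is what lets from_pretrained rebuild the right classifier shape.
model.save_pretrained(save_dir)
files = sorted(os.listdir(save_dir))

# Reloading from the directory rebuilds the model with the saved shape,
# so there is no size mismatch.
reloaded = BertForSequenceClassification.from_pretrained(save_dir)
```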

If you didn't save it with save_pretrained but with torch.save or some other method, so that you have a pytorch_model.bin file containing your model's state dict, you can initialize a configuration from your initial checkpoint (in this case I guess it's bert-base-cased), assign three labels to it, and then load your model by specifying which configuration to use:

import torch
from transformers import BertForSequenceClassification, BertConfig

config = BertConfig.from_pretrained("bert-base-cased", num_labels=3)
model = BertForSequenceClassification.from_pretrained("bert-base-cased", config=config)
model.load_state_dict(torch.load("SAVED_SST_MODEL_DIR/pytorch_model.bin"))

Let me know how it works out for you.

All 3 comments


Yes!!!
Setting num_labels works!
I also found that if I delete classifier.weight and classifier.bias before calling torch.save(model_to_save.state_dict(), output_model_file), the resulting pytorch_model.bin loads fine for further fine-tuning. This model can also be used for QA or MultipleChoice.
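The head-stripping trick can be sketched like this. The tiny random configs are stand-ins so the snippet runs without real checkpoints; note that the exact key names are classifier.weight and classifier.bias, and that loading with strict=False lets the new head stay missing from the checkpoint:

```python
import os
import tempfile

import torch
from transformers import BertConfig, BertForSequenceClassification


def tiny_config(num_labels):
    # Tiny random config as a stand-in for bert-base-uncased.
    return BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1,
                      num_attention_heads=2, intermediate_size=64,
                      num_labels=num_labels)


# "SST-2" model with 2 labels; drop the head keys before saving.
sst_model = BertForSequenceClassification(tiny_config(num_labels=2))
state_dict = sst_model.state_dict()
for key in ("classifier.weight", "classifier.bias"):
    state_dict.pop(key, None)

output_model_file = os.path.join(tempfile.mkdtemp(), "pytorch_model.bin")
torch.save(state_dict, output_model_file)

# "MNLI" model with 3 labels; strict=False allows the freshly initialized
# head to be absent from the checkpoint instead of raising a size mismatch.
mnli_model = BertForSequenceClassification(tiny_config(num_labels=3))
result = mnli_model.load_state_dict(torch.load(output_model_file), strict=False)
```

`load_state_dict` returns the missing and unexpected keys, so you can verify that only the classifier head was skipped while every encoder weight was loaded.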


This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
