Transformers: BERT checkpoint breaks 3.0.2 -> 3.1.0 due to new buffer in BertEmbeddings

Created on 1 Sep 2020 · 22 comments · Source: huggingface/transformers

Hi,

Thanks for the great library. I noticed this line being added (https://github.com/huggingface/transformers/blob/v3.1.0/src/transformers/modeling_bert.py#L190) in the latest update.

It breaks checkpoints that were saved when this line wasn't there.

    Missing key(s) in state_dict: "generator_model.electra.embeddings.position_ids", "discriminator_model.electra.embeddings.position_ids". 
Labels: wontfix

Most helpful comment

You can also use the load_state_dict method with the strict option set to False:

model.load_state_dict(state_dict, strict=False)

All 22 comments

I understand it makes the code slightly cleaner; in terms of speed it is most likely negligible (compared to the embedding lookup, for example).

But I'm not sure what to do now, as all the pretrained models (which took a lot of compute to pretrain) no longer load in the new release.

Hey @Laksh1997 - note that this line does not break anything. You can ignore warnings about position_ids since that buffer is created at instantiation. Will open a PR to fix the warning.
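
For context, the buffer is registered in BertEmbeddings at the line linked in the issue description, so its value comes from instantiation rather than from the checkpoint. Below is a minimal, self-contained sketch of why a missing buffer key is harmless under a non-strict load; the module and variable names here are illustrative, not from the library:

import torch
from torch import nn

class EmbeddingsWithBuffer(nn.Module):
    """Illustrative stand-in for an embeddings module that registers position_ids."""

    def __init__(self, max_position_embeddings=512):
        super().__init__()
        # The buffer is created at instantiation (mirroring the v3.1.0 BertEmbeddings line),
        # so it does not need to be present in an old checkpoint.
        self.register_buffer(
            "position_ids",
            torch.arange(max_position_embeddings).expand((1, -1)),
        )

module = EmbeddingsWithBuffer()
old_state_dict = {}  # stands in for a v3.0.2 checkpoint that lacks "position_ids"
# strict=False tolerates the missing key; the value set in __init__ is kept.
missing, unexpected = module.load_state_dict(old_state_dict, strict=False)
print(missing)      # ['position_ids']
print(unexpected)   # []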

@patrickvonplaten seems to break it for me:

Traceback (most recent call last):
  File "/opt/conda/envs/py36/bin/transformervae", line 33, in <module>
    sys.exit(load_entry_point('exs-transformervae', 'console_scripts', 'transformervae')())
  File "/opt/conda/envs/py36/lib/python3.6/site-packages/click/core.py", line 829, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/envs/py36/lib/python3.6/site-packages/click/core.py", line 782, in main
    rv = self.invoke(ctx)
  File "/opt/conda/envs/py36/lib/python3.6/site-packages/click/core.py", line 1259, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/envs/py36/lib/python3.6/site-packages/click/core.py", line 1066, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/envs/py36/lib/python3.6/site-packages/click/core.py", line 610, in invoke
    return callback(*args, **kwargs)
  File "/app/transformervae/cli.py", line 355, in train
    model = model_cls(hparams, pretrained_model=pretrained_model_path_or_config)
  File "/app/transformervae/models/regression.py", line 35, in __init__
    pretrained_model,
  File "/app/transformervae/models/finetuning_model.py", line 37, in __init__
    self.encoder, self.tokenizer = self.load_pretrained_encoder(pretrained_model)
  File "/app/transformervae/models/finetuning_model.py", line 89, in load_pretrained_encoder
    pl_model = AutoModel.load(pretrained_model)
  File "/app/transformervae/models/automodel.py", line 98, in load
    return model_cls.load(path)
  File "/app/transformervae/models/base.py", line 229, in load
    return cls.load_from_checkpoint(filepath)
  File "/opt/conda/envs/py36/lib/python3.6/site-packages/pytorch_lightning/core/saving.py", line 169, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, *args, **kwargs)
  File "/opt/conda/envs/py36/lib/python3.6/site-packages/pytorch_lightning/core/saving.py", line 207, in _load_model_state
    model.load_state_dict(checkpoint['state_dict'])
  File "/opt/conda/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1045, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ElectraLanguageModel:
        Missing key(s) in state_dict: "generator_model.electra.embeddings.position_ids", "discriminator_model.electra.embeddings.position_ids".

Note that generator_model.electra is an ElectraModel, which uses BertEmbeddings.

Can you send me a code snippet so that I can reproduce your error?

It's a big library, but I can try to recreate it in a Colab. One sec.

@patrickvonplaten Colab: https://colab.research.google.com/drive/167CwTImG5T-4c9xeIVEkH9Xrracbn30h?usp=sharing

Let me know if you can access it.

It also breaks for me. The embeddings.position_ids attribute can't be loaded if the model artifact was trained with v3.0.2, so it raises a KeyError.

Hey @Laksh1997, I can't access the notebook - could you make it public for everybody to see? :-)

@patrickvonplaten apologies. Here is the script:

!pip install transformers==3.0.2

from transformers import ElectraModel, ElectraConfig
import torch
import transformers

print(transformers.__version__)

model = ElectraModel(ElectraConfig())
state_dict = model.state_dict()
torch.save(state_dict, 'checkpoint.pt')
!pip install transformers==3.1.0

from transformers import ElectraModel, ElectraConfig
import torch
import transformers

print(transformers.__version__)

model = ElectraModel(ElectraConfig())
state_dict = torch.load('checkpoint.pt')
model.load_state_dict(state_dict)

I encountered the same issue. Old checkpoints (3.0.2) cannot be loaded in 3.1.0 due to a KeyError.

@Barcavin @easonnie As a temporary fix, I've just reverted back to 3.0.2. @patrickvonplaten I am hoping something can be done!

Hi, while we work on patching this issue, you can still use version 3.1.0 by using the from_pretrained method. Taking @Laksh1997's example, you would do:

  1. Save the checkpoint in saved_model_location/pytorch_model.bin
from transformers import ElectraModel, ElectraConfig
import torch
import transformers

print(transformers.__version__)

model = ElectraModel(ElectraConfig())
state_dict = model.state_dict()
torch.save(state_dict, 'saved_model_location/pytorch_model.bin')
  2. Load it using the method .from_pretrained
from transformers import ElectraModel, ElectraConfig
import transformers

print(transformers.__version__)

model = ElectraModel.from_pretrained("saved_model_location", config=ElectraConfig())
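
For completeness, the same round trip can also be done with the library's own save API instead of a raw torch.save; a sketch using the same ElectraModel as above (save_pretrained writes pytorch_model.bin and config.json into the directory, and from_pretrained then tolerates the new buffer):

from transformers import ElectraModel, ElectraConfig

model = ElectraModel(ElectraConfig())
model.save_pretrained("saved_model_location")                  # pytorch_model.bin + config.json
model = ElectraModel.from_pretrained("saved_model_location")   # loads cleanly under v3.1.0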

You can also use the load_state_dict method with the strict option set to False:

model.load_state_dict(state_dict, strict=False)
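
If you take this route, it is worth inspecting what the non-strict load actually skipped, so that only the new buffer is silently ignored; load_state_dict returns the missing and unexpected keys:

result = model.load_state_dict(state_dict, strict=False)
print(result.missing_keys)     # expect only '...embeddings.position_ids' entries
print(result.unexpected_keys)  # expect an empty list for a 3.0.2 -> 3.1.0 load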

This additional buffer exists now because of this PR.

Is there a reason why you would use the load_state_dict instead of from_pretrained, as from_pretrained exists in part to prevent such issues from happening?

Hi @LysandreJik

Thanks for the proposed solution.

In my case, I am using Pytorch Lightning which has its own saving and loading infrastructure. Thus the from_pretrained method can't exactly be used.

The strict flag is a good patch for now.

I think that, in general, complex projects built on top of the library cannot rely on from_pretrained, especially when they use other ecosystems.

Using the strict flag can enable a number of errors to go undetected, so I would refrain from using it. I think the best solution is to use version 3.0.2 for already trained models until the fix comes out.

Any update on this @LysandreJik @patrickvonplaten ?

As the load_state_dict method in strict mode does not allow unexpected/missing keys, this is an issue that won't be resolved. Three options are available here:

  • Use the recommended from_pretrained method, which exists specifically to work around this kind of issue
  • Use the load_state_dict method with the strict flag set to False
  • Pin to version 3.0.2 if none of these can be applied.

Minor changes in model infrastructure can unfortunately happen as we try to optimize for efficiency, and they can lead to issues like this one. We're working internally on making our models on the hub versionable, which should solve most of these problems. It's at least a couple of months away, however.

@LysandreJik It is unfortunate that the library will probably have to be pinned, as the first two options are unviable for the reasons described in this thread, especially because pretraining large models is computationally quite expensive (hundreds of GPU hours)...

You can also use the work-around explained here if you want to convert your weights to the updated architecture.
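
Along those lines, here is a minimal sketch of converting an old Lightning checkpoint by adding the missing buffer keys directly to its state dict. The file names are placeholders, the key prefixes are taken from the traceback above, and max_position_embeddings must match the config used at pretraining time:

import torch

ckpt = torch.load("old_checkpoint.ckpt", map_location="cpu")   # placeholder path
state_dict = ckpt["state_dict"]

max_len = 512  # must match config.max_position_embeddings of the pretrained model
position_ids = torch.arange(max_len).expand((1, -1))

# Key prefixes follow the ElectraLanguageModel traceback earlier in this thread.
for prefix in ("generator_model.electra", "discriminator_model.electra"):
    state_dict[prefix + ".embeddings.position_ids"] = position_ids

torch.save(ckpt, "patched_checkpoint.ckpt")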

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
