I have trained a small Transformer model on WMT18 De-En data. Now I want to do transfer learning from the pretrained weights to a different language pair, say X-En, where X is any other language. Obviously, the vocabulary for the new source language need not be the same, so I would need to reinitialize the word embeddings on the source side.
In reference to the linked issue:
It seems the issue there was a different one: how to load the model weights, not the embedding weights. I am currently experiencing the same problem and would like to know whether there is an existing solution.
_Originally posted by @skeshaw in https://github.com/pytorch/fairseq/issues/429#issuecomment-535927999_
Currently there's nothing in place to do this; however, you can modify the upgrade_state_dict_named function in the Transformer class to exclude the embedding weights. Maybe something like:
```python
def upgrade_state_dict_named(self, state_dict, name):
    """Upgrade a (possibly old) state dict for new versions of fairseq."""
    # Drop the embedding-table weights so they are not restored from the
    # checkpoint (iterate over a copy of the keys since we delete entries).
    for k in list(state_dict.keys()):
        if 'embed_tokens' in k:
            del state_dict[k]
    if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
        weights_key = '{}.embed_positions.weights'.format(name)
        if weights_key in state_dict:
            del state_dict[weights_key]
        state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
    for i in range(len(self.layers)):
        # update layer norms
        self.layers[i].upgrade_state_dict_named(state_dict, "{}.layers.{}".format(name, i))
    version_key = '{}.version'.format(name)
    if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
        # earlier checkpoints did not normalize after the stack of layers
        self.layer_norm = None
        self.normalize = False
        state_dict[version_key] = torch.Tensor([1])
    return state_dict
```
Hi Matt,
Thanks for your reply.
I tried what you suggested, but it doesn't seem to work. To minimize the number of changes, I also reused the target dictionary, as the target language is English in both cases.
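(Reusing the target dictionary amounts to passing --tgtdict at the fairseq-preprocess step; the command below is only a sketch, with placeholder paths and a placeholder source-language code.)

```bash
# Sketch: binarize the new X-En data while reusing the pretrained model's
# English target dictionary, so only the source-side vocabulary changes.
# All paths and the "xx" language code are placeholders.
fairseq-preprocess --source-lang xx --target-lang en \
    --trainpref data/train --validpref data/valid --testpref data/test \
    --tgtdict data-bin/wmt18_de_en/dict.en.txt \
    --destdir data-bin/new_pair
```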
In each case, I am getting the same error as before.
As you can see from the key listing below, the embedding weights have indeed been deleted from the state dict:
odict_keys(['encoder.version', 'encoder.embed_positions._float_tensor', 'encoder.layers.0.self_attn.in_proj_weight', 'encoder.layers.0.self_attn.in_proj_bias', 'encoder.layers.0.self_attn.out_proj.weight', 'encoder.layers.0.self_attn.out_proj.bias', 'encoder.layers.0.self_attn_layer_norm.weight', 'encoder.layers.0.self_attn_layer_norm.bias', 'encoder.layers.0.fc1.weight', 'encoder.layers.0.fc1.bias', 'encoder.layers.0.fc2.weight', 'encoder.layers.0.fc2.bias', 'encoder.layers.0.final_layer_norm.weight', 'encoder.layers.0.final_layer_norm.bias', 'encoder.layers.1.self_attn.in_proj_weight', 'encoder.layers.1.self_attn.in_proj_bias', 'encoder.layers.1.self_attn.out_proj.weight', 'encoder.layers.1.self_attn.out_proj.bias', 'encoder.layers.1.self_attn_layer_norm.weight', 'encoder.layers.1.self_attn_layer_norm.bias', 'encoder.layers.1.fc1.weight', 'encoder.layers.1.fc1.bias', 'encoder.layers.1.fc2.weight', 'encoder.layers.1.fc2.bias', 'encoder.layers.1.final_layer_norm.weight', 'encoder.layers.1.final_layer_norm.bias', 'encoder.layers.2.self_attn.in_proj_weight', 'encoder.layers.2.self_attn.in_proj_bias', 'encoder.layers.2.self_attn.out_proj.weight', 'encoder.layers.2.self_attn.out_proj.bias', 'encoder.layers.2.self_attn_layer_norm.weight', 'encoder.layers.2.self_attn_layer_norm.bias', 'encoder.layers.2.fc1.weight', 'encoder.layers.2.fc1.bias', 'encoder.layers.2.fc2.weight', 'encoder.layers.2.fc2.bias', 'encoder.layers.2.final_layer_norm.weight', 'encoder.layers.2.final_layer_norm.bias', 'encoder.layers.3.self_attn.in_proj_weight', 'encoder.layers.3.self_attn.in_proj_bias', 'encoder.layers.3.self_attn.out_proj.weight', 'encoder.layers.3.self_attn.out_proj.bias', 'encoder.layers.3.self_attn_layer_norm.weight', 'encoder.layers.3.self_attn_layer_norm.bias', 'encoder.layers.3.fc1.weight', 'encoder.layers.3.fc1.bias', 'encoder.layers.3.fc2.weight', 'encoder.layers.3.fc2.bias', 'encoder.layers.3.final_layer_norm.weight', 'encoder.layers.3.final_layer_norm.bias', 'encoder.layers.4.self_attn.in_proj_weight', 'encoder.layers.4.self_attn.in_proj_bias', 'encoder.layers.4.self_attn.out_proj.weight', 'encoder.layers.4.self_attn.out_proj.bias', 'encoder.layers.4.self_attn_layer_norm.weight', 'encoder.layers.4.self_attn_layer_norm.bias', 'encoder.layers.4.fc1.weight', 'encoder.layers.4.fc1.bias', 'encoder.layers.4.fc2.weight', 'encoder.layers.4.fc2.bias', 'encoder.layers.4.final_layer_norm.weight', 'encoder.layers.4.final_layer_norm.bias', 'encoder.layers.5.self_attn.in_proj_weight', 'encoder.layers.5.self_attn.in_proj_bias', 'encoder.layers.5.self_attn.out_proj.weight', 'encoder.layers.5.self_attn.out_proj.bias', 'encoder.layers.5.self_attn_layer_norm.weight', 'encoder.layers.5.self_attn_layer_norm.bias', 'encoder.layers.5.fc1.weight', 'encoder.layers.5.fc1.bias', 'encoder.layers.5.fc2.weight', 'encoder.layers.5.fc2.bias', 'encoder.layers.5.final_layer_norm.weight', 'encoder.layers.5.final_layer_norm.bias', 'decoder.version', 'decoder.embed_positions._float_tensor', 'decoder.layers.0.self_attn.in_proj_weight', 'decoder.layers.0.self_attn.in_proj_bias', 'decoder.layers.0.self_attn.out_proj.weight', 'decoder.layers.0.self_attn.out_proj.bias', 'decoder.layers.0.self_attn_layer_norm.weight', 'decoder.layers.0.self_attn_layer_norm.bias', 'decoder.layers.0.encoder_attn.in_proj_weight', 'decoder.layers.0.encoder_attn.in_proj_bias', 'decoder.layers.0.encoder_attn.out_proj.weight', 'decoder.layers.0.encoder_attn.out_proj.bias', 'decoder.layers.0.encoder_attn_layer_norm.weight', 
'decoder.layers.0.encoder_attn_layer_norm.bias', 'decoder.layers.0.fc1.weight', 'decoder.layers.0.fc1.bias', 'decoder.layers.0.fc2.weight', 'decoder.layers.0.fc2.bias', 'decoder.layers.0.final_layer_norm.weight', 'decoder.layers.0.final_layer_norm.bias', 'decoder.layers.1.self_attn.in_proj_weight', 'decoder.layers.1.self_attn.in_proj_bias', 'decoder.layers.1.self_attn.out_proj.weight', 'decoder.layers.1.self_attn.out_proj.bias', 'decoder.layers.1.self_attn_layer_norm.weight', 'decoder.layers.1.self_attn_layer_norm.bias', 'decoder.layers.1.encoder_attn.in_proj_weight', 'decoder.layers.1.encoder_attn.in_proj_bias', 'decoder.layers.1.encoder_attn.out_proj.weight', 'decoder.layers.1.encoder_attn.out_proj.bias', 'decoder.layers.1.encoder_attn_layer_norm.weight', 'decoder.layers.1.encoder_attn_layer_norm.bias', 'decoder.layers.1.fc1.weight', 'decoder.layers.1.fc1.bias', 'decoder.layers.1.fc2.weight', 'decoder.layers.1.fc2.bias', 'decoder.layers.1.final_layer_norm.weight', 'decoder.layers.1.final_layer_norm.bias', 'decoder.layers.2.self_attn.in_proj_weight', 'decoder.layers.2.self_attn.in_proj_bias', 'decoder.layers.2.self_attn.out_proj.weight', 'decoder.layers.2.self_attn.out_proj.bias', 'decoder.layers.2.self_attn_layer_norm.weight', 'decoder.layers.2.self_attn_layer_norm.bias', 'decoder.layers.2.encoder_attn.in_proj_weight', 'decoder.layers.2.encoder_attn.in_proj_bias', 'decoder.layers.2.encoder_attn.out_proj.weight', 'decoder.layers.2.encoder_attn.out_proj.bias', 'decoder.layers.2.encoder_attn_layer_norm.weight', 'decoder.layers.2.encoder_attn_layer_norm.bias', 'decoder.layers.2.fc1.weight', 'decoder.layers.2.fc1.bias', 'decoder.layers.2.fc2.weight', 'decoder.layers.2.fc2.bias', 'decoder.layers.2.final_layer_norm.weight', 'decoder.layers.2.final_layer_norm.bias', 'decoder.layers.3.self_attn.in_proj_weight', 'decoder.layers.3.self_attn.in_proj_bias', 'decoder.layers.3.self_attn.out_proj.weight', 'decoder.layers.3.self_attn.out_proj.bias', 'decoder.layers.3.self_attn_layer_norm.weight', 'decoder.layers.3.self_attn_layer_norm.bias', 'decoder.layers.3.encoder_attn.in_proj_weight', 'decoder.layers.3.encoder_attn.in_proj_bias', 'decoder.layers.3.encoder_attn.out_proj.weight', 'decoder.layers.3.encoder_attn.out_proj.bias', 'decoder.layers.3.encoder_attn_layer_norm.weight', 'decoder.layers.3.encoder_attn_layer_norm.bias', 'decoder.layers.3.fc1.weight', 'decoder.layers.3.fc1.bias', 'decoder.layers.3.fc2.weight', 'decoder.layers.3.fc2.bias', 'decoder.layers.3.final_layer_norm.weight', 'decoder.layers.3.final_layer_norm.bias', 'decoder.layers.4.self_attn.in_proj_weight', 'decoder.layers.4.self_attn.in_proj_bias', 'decoder.layers.4.self_attn.out_proj.weight', 'decoder.layers.4.self_attn.out_proj.bias', 'decoder.layers.4.self_attn_layer_norm.weight', 'decoder.layers.4.self_attn_layer_norm.bias', 'decoder.layers.4.encoder_attn.in_proj_weight', 'decoder.layers.4.encoder_attn.in_proj_bias', 'decoder.layers.4.encoder_attn.out_proj.weight', 'decoder.layers.4.encoder_attn.out_proj.bias', 'decoder.layers.4.encoder_attn_layer_norm.weight', 'decoder.layers.4.encoder_attn_layer_norm.bias', 'decoder.layers.4.fc1.weight', 'decoder.layers.4.fc1.bias', 'decoder.layers.4.fc2.weight', 'decoder.layers.4.fc2.bias', 'decoder.layers.4.final_layer_norm.weight', 'decoder.layers.4.final_layer_norm.bias', 'decoder.layers.5.self_attn.in_proj_weight', 'decoder.layers.5.self_attn.in_proj_bias', 'decoder.layers.5.self_attn.out_proj.weight', 'decoder.layers.5.self_attn.out_proj.bias', 'decoder.layers.5.self_attn_layer_norm.weight', 
'decoder.layers.5.self_attn_layer_norm.bias', 'decoder.layers.5.encoder_attn.in_proj_weight', 'decoder.layers.5.encoder_attn.in_proj_bias', 'decoder.layers.5.encoder_attn.out_proj.weight', 'decoder.layers.5.encoder_attn.out_proj.bias', 'decoder.layers.5.encoder_attn_layer_norm.weight', 'decoder.layers.5.encoder_attn_layer_norm.bias', 'decoder.layers.5.fc1.weight', 'decoder.layers.5.fc1.bias', 'decoder.layers.5.fc2.weight', 'decoder.layers.5.fc2.bias', 'decoder.layers.5.final_layer_norm.weight', 'decoder.layers.5.final_layer_norm.bias'])
```
Traceback (most recent call last):
  File "/home/fairseq/fairseq/trainer.py", line 178, in load_checkpoint
    self.get_model().load_state_dict(state['model'], strict=True)
  File "/home/fairseq/fairseq/models/fairseq_model.py", line 69, in load_state_dict
    return super().load_state_dict(state_dict, strict)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 777, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for TransformerModel:
    size mismatch for encoder.embed_tokens.weight: copying a param with shape torch.Size([21768, 300]) from checkpoint, the shape in current model is torch.Size([13056, 300]).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/bin/fairseq-train", line 11, in <module>
    load_entry_point('fairseq', 'console_scripts', 'fairseq-train')()
  File "/home/fairseq/fairseq_cli/train.py", line 327, in cli_main
    main(args)
  File "/home/fairseq/fairseq_cli/train.py", line 70, in main
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
  File "/home/fairseq/fairseq/checkpoint_utils.py", line 109, in load_checkpoint
    reset_meters=args.reset_meters,
  File "/home/fairseq/fairseq/trainer.py", line 184, in load_checkpoint
    'please ensure that the architectures match.'.format(filename)
Exception: Cannot load model parameters from checkpoint /home/ckpts/checkpoint_last.pt; please ensure that the architectures match.
```
So it looks to me like some changes need to be made elsewhere too. I would appreciate it if you could guide me through that.
Sorry, I hadn't actually tested it. The following worked for me:
Modify upgrade_state_dict_named in the TransformerEncoder class as follows. Instead of deleting the embedding keys, overwrite them with the model's freshly initialized embedding weights; the strict load_state_dict then sees tensors with the new vocabulary's shape, while every other parameter is still restored from the checkpoint:
```python
def upgrade_state_dict_named(self, state_dict, name):
    """Upgrade a (possibly old) state dict for new versions of fairseq."""
    # Keep the current (freshly initialized) weights for the encoder
    # embedding table instead of the ones stored in the checkpoint.
    for k in state_dict.keys():
        if 'encoder.embed_tokens' in k:
            state_dict[k] = self.embed_tokens.weight
    if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
        weights_key = '{}.embed_positions.weights'.format(name)
        if weights_key in state_dict:
            del state_dict[weights_key]
        state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
    version_key = '{}.version'.format(name)
    if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
        # earlier checkpoints did not normalize after the stack of layers
        self.layer_norm = None
        self.normalize = False
        state_dict[version_key] = torch.Tensor([1])
    return state_dict
```
I simply copy-pasted your code into the required location, but it still does not seem to work. I am getting the following error now:
```
| epoch 007:   0% 0/37 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/usr/local/bin/fairseq-train", line 11, in <module>
    load_entry_point('fairseq', 'console_scripts', 'fairseq-train')()
  File "/home/fairseq/fairseq_cli/train.py", line 327, in cli_main
    main(args)
  File "/home/fairseq/fairseq_cli/train.py", line 81, in main
    train(args, trainer, task, epoch_itr)
  File "/home/fairseq/fairseq_cli/train.py", line 122, in train
    log_output = trainer.train_step(samples)
  File "/home/fairseq/fairseq/trainer.py", line 405, in train_step
    self.optimizer.step()
  File "/home/fairseq/fairseq/optim/fairseq_optimizer.py", line 98, in step
    self.optimizer.step(closure)
  File "/home/fairseq/fairseq/optim/adam.py", line 160, in step
    exp_avg.mul_(beta1).add_(1 - beta1, grad)
RuntimeError: The size of tensor a (21768) must match the size of tensor b (13056) at non-singleton dimension 0
```
I was getting the exact same error yesterday when, instead of deleting the weights, I manually initialized them using the Embedding function, which is essentially what you are suggesting here. To clarify, 21768 is the source dictionary size of the pretrained model, while 13056 is the size for the new source language.
I will keep looking into the code to solve this, but if you have any other suggestions, those could really save me a lot of time. :)
Thanks.
I think this has to do with the old optimizer state. Can you try --reset-optimizer?
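For example (only a sketch; the data directory, architecture, and other training flags are whatever you are already using, and the checkpoint path is the one from your traceback):

```bash
# Sketch: resume from the pretrained checkpoint but discard the old Adam state.
# Everything except --reset-optimizer is a placeholder for your existing setup.
fairseq-train data-bin/new_pair \
    --arch transformer --restore-file /home/ckpts/checkpoint_last.pt \
    --reset-optimizer
```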
Yes, you are right. It's working now.
BTW, would you mind explaining the purpose of this option? It would help me understand the library a little more.
Adam stores some state about each parameter that it is optimizing. Since you were previously optimizing parameters of a different shape, you'd need to either upgrade the adam state in a similar way you upgraded the state of the transformer params, or you can just start with fresh state (--reset-optimizer).
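As a rough illustration in plain PyTorch (not fairseq-specific; the two sizes are the vocabulary sizes from your error message):

```python
# Adam keeps per-parameter buffers (exp_avg, exp_avg_sq) with the shape of the
# parameter they were created for. Reloading that state for a parameter of a
# different shape fails at the first optimizer step, just like the error above.
import torch

old_emb = torch.nn.Embedding(21768, 300)        # old source vocabulary
old_opt = torch.optim.Adam(old_emb.parameters())
old_emb(torch.tensor([0])).sum().backward()
old_opt.step()                                   # allocates exp_avg of shape [21768, 300]

new_emb = torch.nn.Embedding(13056, 300)         # new source vocabulary
new_opt = torch.optim.Adam(new_emb.parameters())
new_opt.load_state_dict(old_opt.state_dict())    # old per-parameter buffers carry over
new_emb(torch.tensor([0])).sum().backward()
new_opt.step()   # RuntimeError: the size of tensor a (21768) must match ... (13056)
```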
Thank you so much for taking the time.