When I use the pretrained model transformer.wmt16.en-de from the paper Scaling Neural Machine Translation, I get a size mismatch error:
python interactive.py ../wmt16.en-de.joined-dict.transformer/ --path ../wmt16.en-de.joined-dict.transformer/model.pt --task translation --remove-bpe -s en -t de
Namespace(beam=5, bpe=None, buffer_size=1, cpu=False, criterion='cross_entropy', data='../wmt16.en-de.joined-dict.transformer/', dataset_impl=None, decoding_iterations=None, decoding_strategy='left_to_right', dehyphenate=False, diverse_beam_groups=-1, diverse_beam_strength=0.5, force_anneal=None, fp16=False, fp16_init_scale=128, fp16_scale_tolerance=0.0, fp16_scale_window=None, gen_subset='test', gold_target_len=False, input='-', lazy_load=False, left_pad_source='True', left_pad_target='False', length_beam=5, lenpen=1, log_format=None, log_interval=1000, lr_scheduler='fixed', lr_shrink=0.1, match_source_len=False, max_len_a=0, max_len_b=200, max_sentences=1, max_source_positions=1024, max_target_positions=1024, max_tokens=None, memory_efficient_fp16=False, min_len=1, min_loss_scale=0.0001, model_overrides='{}', momentum=0.99, nbest=1, no_beamable_mm=False, no_early_stop=False, no_progress_bar=False, no_repeat_ngram_size=0, num_shards=1, num_workers=0, optimizer='nag', path='../wmt16.en-de.joined-dict.transformer/model.pt', prefix_size=0, print_alignment=False, quiet=False, raw_text=False, remove_bpe='@@ ', replace_unk=None, required_batch_size_multiple=8, results_path=None, sacrebleu=False, sampling=False, sampling_topk=-1, sampling_topp=-1.0, score_reference=False, seed=1, shard_id=0, skip_invalid_size_inputs_valid_test=False, source_lang='en', target_lang='de', task='translation', tbmf_wrapper=False, temperature=1.0, tensorboard_logdir='', threshold_loss_scale=None, tokenizer=None, unkpen=0, unnormalized=False, upsample_primary=1, user_dir=None, warmup_updates=0, weight_decay=0.0)
| [en] dictionary: 32769 types
| [de] dictionary: 32769 types
| loading model(s) from ../wmt16.en-de.joined-dict.transformer/model.pt
Traceback (most recent call last):
File "interactive.py", line 195, in <module>
cli_main()
File "interactive.py", line 191, in cli_main
main(args)
File "interactive.py", line 84, in main
task=task,
File "/root/code/ft_local/Mask-Predict-master/fairseq/checkpoint_utils.py", line 156, in load_model_ensemble
ensemble, args, _task = load_model_ensemble_and_task(filenames, arg_overrides, task)
File "/root/code/ft_local/Mask-Predict-master/fairseq/checkpoint_utils.py", line 175, in load_model_ensemble_and_task
model.load_state_dict(state['model'], strict=True)
File "/root/code/ft_local/Mask-Predict-master/fairseq/models/fairseq_model.py", line 72, in load_state_dict
return super().load_state_dict(state_dict, strict)
File "/root/miniconda2/envs/py3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 839, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for TransformerModel:
size mismatch for encoder.embed_tokens.weight: copying a param with shape torch.Size([32768, 1024]) from checkpoint, the shape in current model is torch.Size([32769, 1024]).
size mismatch for decoder.embed_tokens.weight: copying a param with shape torch.Size([32768, 1024]) from checkpoint, the shape in current model is torch.Size([32769, 1024]).
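For anyone debugging a similar failure: the traceback shows the checkpoint's embedding tables have one fewer row than the freshly built model, i.e. the dictionary used at preprocessing time has one more type than the one the checkpoint was trained with. A minimal sketch of how to diff the shapes before calling load_state_dict(strict=True) — plain Python, with the shapes hard-coded from the traceback above and a hypothetical helper name:

```python
# Shapes copied from the traceback above (checkpoint vs current model).
checkpoint_shapes = {
    "encoder.embed_tokens.weight": (32768, 1024),
    "decoder.embed_tokens.weight": (32768, 1024),
}
model_shapes = {
    "encoder.embed_tokens.weight": (32769, 1024),
    "decoder.embed_tokens.weight": (32769, 1024),
}

def find_mismatches(ckpt, model):
    """Return parameter names whose shapes differ between checkpoint and model.

    In practice you would fill these dicts from
    torch.load(path)["model"] and model.state_dict() respectively.
    """
    return [name for name, shape in ckpt.items() if model.get(name) != shape]

for name in find_mismatches(checkpoint_shapes, model_shapes):
    print(f"size mismatch for {name}: "
          f"checkpoint {checkpoint_shapes[name]} vs model {model_shapes[name]}")
```

An off-by-one in the embedding rows almost always points at a dictionary mismatch rather than a model-architecture mismatch, which is what the reply below addresses.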
@myleott
How did you pre-process the data? If you want to use the pre-trained model provided in the README, you'll need to provide the dictionaries from the tar file. Specifically:
fairseq-preprocess \
--source-lang en --target-lang de \
--trainpref $TEXT/train.tok.clean.bpe.32000 \
--validpref $TEXT/newstest2013.tok.bpe.32000 \
--testpref $TEXT/newstest2014.tok.bpe.32000 \
--destdir data-bin/wmt16_en_de_bpe32k --workers 20 \
--joined-dictionary --srcdict wmt16.en-de.joined-dict.transformer/dict.en.txt
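As a sanity check after re-running fairseq-preprocess: fairseq's Dictionary prepends special symbols (to my understanding, 4 of them: <s>, <pad>, </s>, <unk>) to the entries read from dict.en.txt, so the reported vocabulary size should equal the dictionary's line count plus 4. A small sketch, assuming that default of 4 special symbols:

```python
# Assumption: fairseq's Dictionary adds 4 special symbols
# (<s>, <pad>, </s>, <unk>) on top of the entries in dict.en.txt.
NUM_SPECIAL = 4

def expected_vocab_size(num_dict_lines, num_special=NUM_SPECIAL):
    """Vocabulary size fairseq should report for a given dict file.

    num_dict_lines would come from e.g.
    sum(1 for _ in open("dict.en.txt")).
    """
    return num_dict_lines + num_special

# The checkpoint's embedding has 32768 rows, so the supplied dict.en.txt
# should contain 32768 - 4 = 32764 entries (hypothetical line count).
print(expected_vocab_size(32764))  # -> 32768
```

If your own preprocessing reports 32769 types (as in the log above), one extra symbol crept into the dictionary; using --srcdict with the dictionary shipped in the tar file avoids that.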
Thanks! @lematt1991