Fairseq: Multilingual_translation error at inference

Created on 12 Feb 2019 · 10 comments · Source: pytorch/fairseq

For preprocessing, you can use the same commands as in the documentation for each language pair. For example:

fairseq-preprocess --source-lang de --target-lang en \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/
Then execute the same command for the second language pair:

fairseq-preprocess --source-lang it --target-lang en \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/
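
Since both runs share the same --destdir, the de-en and it-en files end up side by side in data-bin/. A small sketch of the filenames fairseq-preprocess conventionally writes per pair (illustrative only, not fairseq code):

```python
# Illustrative sketch: filenames fairseq-preprocess conventionally writes
# into --destdir for one language pair (not actual fairseq code).
def expected_files(src, tgt, splits=("train", "valid", "test")):
    files = [f"dict.{src}.txt", f"dict.{tgt}.txt"]  # vocabularies
    for split in splits:
        for lang in (src, tgt):
            # binarized data: a data file and an index file per side
            files.append(f"{split}.{src}-{tgt}.{lang}.bin")
            files.append(f"{split}.{src}-{tgt}.{lang}.idx")
    return files

# Both pairs share data-bin/, so de-en and it-en files sit side by side,
# and dict.en.txt is produced by both runs.
print(sorted(set(expected_files("de", "en") + expected_files("it", "en"))))
```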

As for the training, here is a sample command that I used:
python train.py \raw-data\data-bin --task multilingual_translation --criterion label_smoothed_cross_entropy --arch multilingual_transformer --max-epoch 26 --lr 1.0 --wd 0.5 --lang-pairs de-en,it-en --encoder-layers 2 --decoder-layers 2 --save-dir data\checkpoints --optimizer sgd

Inference:
fairseq-interactive \raw-data\data-bin --task multilingual_translation --source-lang it --target-lang en --path \checkpoints\checkpoint20.pt --input \raw-data\test.it --beam 5

For the inference however, I am stuck at this error:
Traceback (most recent call last):

File "c:...\lib\site-packages\fairseq_cliinteractive.py", line 82, in main
args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
File "c:...\lib\site-packages\fairseq\utils.py", line 164, in load_ensemble_for_inference
model = task.build_model(args)
File "c:...\lib\site-packages\fairseq\tasks\multilingual_translation.py", line 180, in build_model
model = models.build_model(args, self)
File "c:...\lib\site-packages\fairseq\models__init__.py", line 33, in build_model
return ARCH_MODEL_REGISTRY[args.arch].build_model(args, task)
File "c:....\lib\site-packages\fairseq\models\multilingual_transformer.py", line 162, in build_model
encoders[lang_pair] = shared_encoder if shared_encoder is not None else get_encoder(src)
File "c:.....\lib\site-packages\fairseq\models\multilingual_transformer.py", line 137, in get_encoder
task.dicts[lang], args.encoder_embed_dim, args.encoder_embed_path
KeyError: 'de'

It seems like it requires the second source language (de), but I really don't know why, and how to solve this. I hope someone tells me what I am missing.
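
For context, here is a minimal stand-in for the failing path in build_model (the names are illustrative, not fairseq's actual API): the checkpoint remembers both training pairs, so the model tries to build an encoder for each, but at inference only the it/en dictionaries were loaded.

```python
# Simplified stand-in for the failing path in build_model (illustrative
# names, not fairseq's API): the checkpoint stores both training pairs,
# but the inference task only loaded dictionaries for --source-lang it.
checkpoint_lang_pairs = ["de-en", "it-en"]       # saved with the model
task_dicts = {"it": "it-dict", "en": "en-dict"}  # loaded at inference

def build_encoders(lang_pairs, dicts):
    encoders = {}
    for pair in lang_pairs:
        src, _tgt = pair.split("-")
        encoders[pair] = dicts[src]  # KeyError: 'de' — no dict loaded for de
    return encoders

try:
    build_encoders(checkpoint_lang_pairs, task_dicts)
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'de'
```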

_Originally posted by @AyaNsar in https://github.com/pytorch/fairseq/issues/497#issuecomment-462551184_

All 10 comments

Thanks for reporting the error. This is a bug, and I will fix the problem shortly.

Any new updates regarding the error?

The fix is still under review. To unblock you, here is the workaround:
Replace the code with the following:

        args.lang_pairs = args.lang_pairs.split(',')
        if args.source_lang is not None or args.target_lang is not None:
            training = False
        else:
            training = True
            args.source_lang, args.target_lang = args.lang_pairs[0].split('-')
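
The patch's logic, restated as a standalone function (a sketch, not the fairseq code itself): when --source-lang/--target-lang are absent we are training and default them from the first pair; when they are given, we are at inference.

```python
# Standalone sketch of the workaround's logic (not the actual fairseq code).
def resolve_langs(lang_pairs_str, source_lang=None, target_lang=None):
    lang_pairs = lang_pairs_str.split(",")
    if source_lang is not None or target_lang is not None:
        training = False   # inference: the user picked a pair explicitly
    else:
        training = True    # training: default to the first language pair
        source_lang, target_lang = lang_pairs[0].split("-")
    return lang_pairs, training, source_lang, target_lang

print(resolve_langs("de-en,it-en"))              # training defaults to de-en
print(resolve_langs("de-en,it-en", "it", "en"))  # inference on it-en
```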

At inference, we need to pass the same --lang-pairs xxxx as in the training input. In @AyaNsar's example, the inference command becomes:

fairseq-interactive \raw-data\data-bin --task multilingual_translation --source-lang it --target-lang en --path \checkpoints\checkpoint20.pt --input \raw-data\test.it --beam 5 --lang-pairs de-en,it-en

Sorry, I changed the workaround above and we need one more change. Change this line to the following:

eval_key=None if self.training else "%s-%s" % (self.args.source_lang, self.args.target_lang),

Thanks a lot for the update!
I understand what you have done, and I changed the code just as you said, except that I left out line 100:
langs = list({x for lang_pair in args.lang_pairs for x in lang_pair.split('-')})

because otherwise this error comes up: NameError: name 'langs' is not defined.

However, another issue comes up now:

File "\fairseq-interactive.exe__main__.py", line 9, in
File "\fairseq_cliinteractive.py", line 189, in cli_main
main(args)
File "\fairseq_cliinteractive.py", line 172, in main
for batch, batch_indices in make_batches(inputs, args, task, max_positions):
File "\fairseq_cliinteractive.py", line 49, in make_batches
max_positions=max_positions,
File "\fairseq\tasks\fairseq_task.py", line 152, in get_batch_iterator
num_workers=num_workers,
File "\fairseq\data\iterators.py", line 89, in __init__
self.frozen_batches = tuple(batch_sampler)
File "\fairseq\data\data_utils.py", line 157, in batch_by_size
for idx in indices:
File "\fairseq\data\data_utils.py", line 106, in filter_by_size
for idx in itr:
File "\fairseq\data\data_utils.py", line 69, in collect_filtered
if function(el):
File "\fairseq\data\data_utils.py", line 92, in check_size
assert isinstance(idx_size, dict)
AssertionError

Ah, sorry, for that we need another patch:

Change this to:

        dataset=task.build_dataset(tokens, lengths, task.source_dictionary) \
                if 'build_dataset' in dir(task) else \
                data.LanguagePairDataset(tokens, lengths, task.source_dictionary),

and add this function to fairseq/tasks/multilingual_translation.py:

    def build_dataset(self, tokens, lengths, src_dict):
        lang_pair = "%s-%s" % (self.args.source_lang, self.args.target_lang)
        return RoundRobinZipDatasets(
            OrderedDict([
                (lang_pair, LanguagePairDataset(tokens, lengths, src_dict))
            ]),
            eval_key=lang_pair,
        )
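
To see why eval_key matters here, a simplified stand-in for RoundRobinZipDatasets (a sketch; the real class lives in fairseq/data): with eval_key set, indexing passes straight through to the single pair's dataset, so the zipped dataset behaves like a plain single-pair dataset, which is what the batching code expects at inference.

```python
from collections import OrderedDict

# Simplified stand-in for fairseq's RoundRobinZipDatasets, to illustrate the
# role of eval_key (a sketch, not the real implementation).
class RoundRobinZip:
    def __init__(self, datasets, eval_key=None):
        self.datasets = datasets
        self.eval_key = eval_key

    def __getitem__(self, i):
        if self.eval_key is not None:
            # Pass through to the single evaluated pair's dataset.
            return self.datasets[self.eval_key][i]
        # Training mode: return one sample per language pair.
        return OrderedDict((k, d[i]) for k, d in self.datasets.items())

ds = RoundRobinZip(OrderedDict([("it-en", ["ciao", "mondo"])]), eval_key="it-en")
print(ds[0])  # "ciao" — behaves like a plain single-pair dataset
```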

Can't thank you enough! It works like a charm.

The fix is merged into master [1]. Please pull the latest version.
After the fix, we don't need to specify --lang-pairs during inference time anymore.

[1] https://github.com/pytorch/fairseq/pull/505

Sorry, the PR is not in master yet. I will keep the issue open until the merge happens.

PR is already merged. Closing.
