Dear all,
I have trained a multilingual Transformer with a shared encoder and decoder. I would like to use the same model for inference on an unseen language pair (both languages were seen during training, but in other directions). If I try to generate for this unseen pair, passing the corresponding source and target language arguments, Fairseq tries to load a model from a nonexistent key and raises an error. However, I just want to apply the existing model (which is shared across all languages), so it should be possible in principle.
What is the canonical way to do this? Is it implemented?
Many thanks in advance.
It isn't currently implemented; however, you should be able to hack something together by modifying or creating a new checkpoint from your original one. The multilingual models maintain a dictionary of encoders and decoders keyed by language pair, so if you test on an unseen language pair you'll get a `KeyError` exception. You should be able to work around this by adding the new language pair to the checkpoint, copying the encoder and decoder you want to use.
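For illustration, here is a minimal sketch of that checkpoint-surgery idea, done offline on the saved checkpoint file rather than inside fairseq. The file names and the language pairs (`de-en`, `en-fr`, `de-fr`) are made up for illustration; the `models.<src-tgt>.<module>...` key layout and the `model` entry in the checkpoint dict are assumptions about how fairseq stores multilingual Transformer checkpoints.

```python
# Hypothetical offline checkpoint surgery: copy encoder/decoder weights
# from seen language pairs into an unseen pair before generation.
import torch

ckpt = torch.load('checkpoint_best.pt', map_location='cpu')
state = ckpt['model']  # flat state dict, keys like 'models.de-en.encoder....'

new_pair = 'de-fr'  # hypothetical unseen direction to add
new_src, new_tgt = new_pair.split('-')

for key in list(state.keys()):
    if not key.startswith('models.'):
        continue
    parts = key.split('.')
    pair, module = parts[1], parts[2]
    src, tgt = pair.split('-')
    # Borrow the encoder from a seen pair with the same source language,
    # and the decoder from a seen pair with the same target language.
    if (module == 'encoder' and src == new_src) or \
       (module == 'decoder' and tgt == new_tgt):
        new_key = '.'.join(['models', new_pair] + parts[2:])
        if new_key not in state:
            state[new_key] = state[key].clone()

torch.save(ckpt, 'checkpoint_zero_shot.pt')
```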
That's more or less what I did, in the `load_state_dict` method. It would probably be worth implementing this as a fully supported Fairseq feature. Thank you very much.
Hi @jordiae, could you share the code you modified to test the model on an unseen language pair? I tried to modify the code, but I am getting an error in the generate function of the SequenceGenerator.
Dear @murthyrudra, sorry for the delay. It depends on the specific Fairseq version, but modifying the `load_state_dict` method in `fairseq/models/multilingual_transformer.py` worked for me. I just copied the weights from existing models. I did something like:
```python
def load_state_dict(self, state_dict, strict=True, args=None):
    # Start from a copy of the checkpoint's state dict, then add entries for
    # any language pair in self.models that the checkpoint does not cover.
    state_dict_subset = state_dict.copy()
    for model in self.models:  # keys are language pairs, e.g. 'en-de'
        src_desired_model, tgt_desired_model = model.split('-')
        for k, _ in state_dict.items():
            assert k.startswith('models.')
            k_elements = k.split('.')
            new_k_elements = k_elements.copy()
            src_existent_model, tgt_existent_model = k_elements[1].split('-')
            # Reuse the encoder of a seen pair with the same source language.
            if k_elements[2] == 'encoder' and src_desired_model == src_existent_model:
                new_k_elements[1] = model
                new_key = '.'.join(new_k_elements)
                if new_key not in state_dict_subset:
                    state_dict_subset[new_key] = state_dict[k].clone()
            # Reuse the decoder of a seen pair with the same target language.
            if k_elements[2] == 'decoder' and tgt_desired_model == tgt_existent_model:
                new_k_elements[1] = model
                new_key = '.'.join(new_k_elements)
                if new_key not in state_dict_subset:
                    state_dict_subset[new_key] = state_dict[k].clone()
    # The parent call's signature varies across fairseq versions.
    super().load_state_dict(state_dict_subset, strict=strict, args=args)
```
Please note that it's quick-and-dirty code :)
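For completeness: once `load_state_dict` adds weights for the new pair, generation should go through the usual multilingual task, e.g. by including the unseen pair in `--lang-pairs` and setting `--source-lang`/`--target-lang` accordingly when invoking `fairseq-generate`. The exact arguments depend on the fairseq version, and dictionaries and binarized test data for the new direction need to be in place.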
@jordiae Thank you very much :)
Hello @jordiae, thanks for sharing your code snippet. I tried zero-shot inference using your version of `load_state_dict`; however, the BLEU score drops to 0 even for the original pairs. The BLEU is non-zero when the original `load_state_dict` method is used.
Hi @masonreznov, and sorry for the (very) late reply! I used that snippet in an old side project and don't remember much about it at this point. I probably copy-pasted an old version of the snippet by mistake... but right now I can't find the one I actually used! If I do, I'll ping you.