Fairseq: How to score text with trained language model

Created on 17 Oct 2019 · 6 comments · Source: pytorch/fairseq

I successfully trained a Transformer language model with fairseq. Now I would like to score text with this model.

This is what I am looking for:

echo "Input text to be scored by lm" | fairseq-score trained_model_path/checkpoint_best.pt
78.23 # example language model perplexity score for this sentence

Alternatively, something like

import torch
from fairseq.models.transformer_lm import TransformerLanguageModel

custom_lm = TransformerLanguageModel.from_pretrained('trained_model_path', 'checkpoint_best.pt')
custom_lm.score('Input text to be scored by lm')
# 78.23 # example language model perplexity score for this sentence

Looking here:

https://github.com/pytorch/fairseq/tree/master/examples/language_model

and here:

https://fairseq.readthedocs.io/en/latest/command_line_tools.html#fairseq-eval-lm

it seems that I have to binarize my test data with fairseq-preprocess, which I want to avoid.

What is the easiest way to score plain text with a trained fairseq LM?

Labels: documentation, enhancement, question

All 6 comments

Found a way to do it:

>>> tokens = 'the potentially medically important signs and symptoms'
>>> num_tokens = len(tokens.split(" "))
>>> custom_lm.sample('the potentially medically important signs and symptoms', verbose=True, max_len_b=num_tokens)
S       the potentially medically important signs and symptoms
H       -3.916642665863037      the potentially medically important signs and symptoms
P       -2.5552 -10.4895 -4.0601 -2.6499 -1.0805 -1.8499 -0.0001 -8.6480
'the potentially medically important signs and symptoms'

This is inconvenient, since sample just prints the scores to STDOUT instead of returning them.
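As an aside, the H value printed above is simply the average of the per-token log-probabilities on the P line, which can be checked directly:

```python
# Per-token log-probabilities from the "P" line printed above.
p_scores = [-2.5552, -10.4895, -4.0601, -2.6499,
            -1.0805, -1.8499, -0.0001, -8.6480]

# The score on the "H" line is the mean per-token log-probability.
h_score = sum(p_scores) / len(p_scores)
print(round(h_score, 4))  # matches the H line above, about -3.9166
```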

My solution, without having much insight into torch and fairseq:

import copy

from fairseq import hub_utils
from fairseq.models.fairseq_model import FairseqLanguageModel


class GeneratorHubInterfaceWithScoring(hub_utils.GeneratorHubInterface):

    def score(self,
              sentence: str,
              verbose: bool = False,
              **kwargs) -> float:

        # Count whitespace-separated tokens to cap generation length below.
        tokens = sentence.split(" ")
        num_tokens = len(tokens)

        # Convert the sentence into a binarized tensor and wrap it in a batch.
        encoded_sentence = self.binarize(sentence)
        sample = self._build_sample(encoded_sentence)

        # build generator using current args as well as any kwargs
        gen_args = copy.copy(self.args)
        gen_args.beam = 1
        gen_args.max_len_b = num_tokens
        for k, v in kwargs.items():
            setattr(gen_args, k, v)
        generator = self.task.build_generator(gen_args)

        translations = self.task.inference_step(generator, self.models, sample)

        # With beam=1 and max_len_b equal to the sentence length, the single
        # hypothesis should be the input itself, scored by the language model.
        hypo = translations[0][0]
        score = hypo['score']

        scored_tokens = hypo['tokens']
        scored_sentence = self.string(scored_tokens)

        assert sentence == scored_sentence, "Input tokens and the ones that are actually scored do not seem identical:\n%s\n%s" % (sentence, scored_sentence)

        if verbose:
            print("TOKENS:\t%s" % scored_tokens)

        return score


class FairseqLanguageModelWithScoring(FairseqLanguageModel):

    @classmethod
    def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', **kwargs):

        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            **kwargs,
        )

        return GeneratorHubInterfaceWithScoring(x['args'], x['task'], x['models'])

Then:

custom_lm = FairseqLanguageModelWithScoring.from_pretrained(args.model_dir, 'checkpoint_best.pt')

Yeah, the sample interface is just a wrapper around encode, generate and decode, so I recommend using those directly for more control: https://github.com/pytorch/fairseq/blob/master/fairseq/hub_utils.py#L119-L122

@myleott Is scoring a string with fairseq LMs not a frequent use case for you? If it is, could you perhaps post your code somewhere?

My solution above works, but it is too slow; for instance, there is no batching.

Good call, we should add better documentation around this, and the proposed score function makes sense. I'll add something like this shortly.

Added a .score function in 9d7725226da3fcd9c5d1ac02473289f53cd7dd78. It should be much faster than using generate.

Usage:

en_lm.score('Barack Obama is coming to Sydney and New Zealand')['positional_scores']
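The returned positional_scores field holds per-token log-probabilities (the same quantities as the P line earlier in the thread), so a sentence-level perplexity can be recovered by exponentiating their negative mean. A minimal sketch, using a stand-in list of log-probabilities in place of a real model's output tensor:

```python
import math

# Stand-in per-token log-probabilities; a real call would use the tensor in
# en_lm.score(...)['positional_scores'].
positional_scores = [-1.5, -3.2, -0.7, -2.1]

# Perplexity is the exponential of the negative mean log-probability.
avg_log_prob = sum(positional_scores) / len(positional_scores)
perplexity = math.exp(-avg_log_prob)
print(round(perplexity, 2))  # 6.52
```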
