Fairseq: Reproducing RoBERTa STS-B results

Created on 12 Aug 2019 · 7 Comments · Source: pytorch/fairseq

Thanks for your great work! Quick request: could you please add how to replicate RoBERTa STS-B results? There are good instructions for fine-tuning the model but I didn't see any for running the prediction task itself. Thanks!

All 7 comments

You can use something like the following code snippet for inference on STS-B.

from fairseq.models.roberta import RobertaModel
from scipy.stats import pearsonr

# Load the fine-tuned checkpoint along with the binarized STS-B data directory.
roberta = RobertaModel.from_pretrained(
    'checkpoints/',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='STS-B-bin'
)

roberta.cuda()
roberta.eval()  # disable dropout for inference
gold, pred = [], []
with open('glue_data/STS-B/dev.tsv') as fin:
    fin.readline()  # skip the header row
    for index, line in enumerate(fin):
        tokens = line.strip().split('\t')
        # Columns 7 and 8 hold the sentence pair; column 9 holds the gold score.
        sent1, sent2, target = tokens[7], tokens[8], float(tokens[9])
        tokens = roberta.encode(sent1, sent2)
        features = roberta.extract_features(tokens)
        # Scale the regression head's output by 5.0 to match the 0-5 range of STS-B scores.
        predictions = 5.0 * roberta.model.classification_heads['sentence_classification_head'](features)
        gold.append(target)
        pred.append(predictions.item())

# pearsonr returns a (correlation, p-value) pair.
print('| Pearson: ', pearsonr(gold, pred))

I will add the above to the README. Thanks for pointing it out.

Thanks so much! Unfortunately I wasn't able to reproduce your STS-B score. The value of predictions.item() was always ~2.84, changing only very slightly for each example. I followed your finetuning instructions for STS-B and didn't change the above code for inference. Any thoughts on what might be going on?

Hmm, can you please share your training logs? Specifically, what loss are you getting?
Also, can you please make sure you see "loaded checkpoint ..." in the training logs?

You're right, I wasn't loading the checkpoint. Fixed that and now I'm able to reproduce STS-B. Thanks again!
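For anyone hitting the same symptom, here is a minimal sketch of a sanity check that flags near-constant predictions, like the ~2.84 values described above, before computing Pearson. The helper name `looks_degenerate` and the `min_std` threshold are hypothetical and not part of fairseq; pick a threshold that suits your data.

```python
import statistics

def looks_degenerate(preds, min_std=0.05):
    """Return True when predictions barely vary across examples.

    min_std is an illustrative threshold (an assumption), not a fairseq setting.
    """
    if len(preds) < 2:
        return True
    return statistics.pstdev(preds) < min_std

# Predictions that hover around one value, as when the fine-tuned
# checkpoint was not actually loaded.
constant_like = [2.84, 2.841, 2.839, 2.84]
# Predictions that spread across the 0-5 range.
varied = [0.5, 4.2, 2.9, 1.1]

print(looks_degenerate(constant_like))
print(looks_degenerate(varied))
```

Calling this on the `pred` list from the inference loop is a cheap way to catch an unloaded checkpoint early.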

Where do we get the checkpoint file?
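Going by the snippet above, the checkpoint is the `checkpoint_best.pt` file that fine-tuning writes into `checkpoints/`, so it only exists after you have run the fine-tuning step yourself. A minimal sketch of a pre-flight check, assuming those paths; the helper `missing_inputs` is hypothetical and not a fairseq function:

```python
from pathlib import Path

def missing_inputs(checkpoint_dir, checkpoint_file, data_dir):
    """Return a list of problems; an empty list means both paths exist."""
    problems = []
    ckpt = Path(checkpoint_dir) / checkpoint_file
    if not ckpt.is_file():
        problems.append(f"checkpoint not found: {ckpt}")
    if not Path(data_dir).is_dir():
        problems.append(f"data directory not found: {data_dir}")
    return problems

# Paths taken from the inference snippet in this thread; adjust to your setup.
print(missing_inputs('checkpoints/', 'checkpoint_best.pt', 'STS-B-bin'))
```

Running this before `RobertaModel.from_pretrained` gives a clearer message than a failed load.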

predictions = 5.0 * roberta.model.classification_heads['sentence_classification_head'](features)

Why do we need to multiply the output of the model by 5.0?
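A possible explanation, assuming the STS-B labels were divided by 5.0 during preprocessing so that the regression target lies in [0, 1] (worth checking against your own preprocessing script): multiplying by 5.0 maps the head's output back onto the 0-5 scale of the gold scores. The sketch below uses made-up numbers and a hand-written `pearson` helper to show that this rescaling changes the predicted values but not the Pearson correlation itself.

```python
import math

def pearson(xs, ys):
    """Plain Pearson correlation coefficient for two equal-length lists."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)

# Illustrative values only (not real model output).
gold = [0.0, 1.5, 3.2, 5.0]       # gold scores on the 0-5 scale
raw = [0.05, 0.35, 0.60, 0.95]    # hypothetical head outputs in [0, 1]
scaled = [5.0 * p for p in raw]   # same outputs mapped back to 0-5

print(pearson(gold, raw))
print(pearson(gold, scaled))
```

So the factor matters when you want predictions that are directly comparable to the gold scores (for example, for inspecting individual errors), while the reported Pearson value is unaffected by it.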

Why do we need to set data_name_or_path to STS-B-bin?
