Transformers: output generate scores per hypothesis/token

Created on 21 Jun 2020 · 12 comments · Source: huggingface/transformers

🚀 Feature request

Thanks for doing such awesome work.
I'm interested in the hypothesis score when running `generate`.
This could be returned per hypothesis, or preferably per token in the hypothesis.

Motivation

The motivation is to obtain a confidence measure for my generated text.

I suggest:

  1. adding a `return_scores` flag to `generate` in modeling_utils.py
  2. in `_generate_beam_search` and `_generate_no_beam_search`, also returning the scores when `return_scores` is set

For `_generate_beam_search`:

```python
    best_scores = []

    # retrieve best hypotheses
    for i, hypotheses in enumerate(generated_hyps):
        sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
        for j in range(output_num_return_sequences_per_batch):
            effective_batch_idx = output_num_return_sequences_per_batch * i + j
            hyp_score, best_hyp = sorted_hyps.pop()
            sent_lengths[effective_batch_idx] = len(best_hyp)
            best.append(best_hyp)
            # keep the cumulative (length-penalized) log-probability of each returned hypothesis
            best_scores.append(hyp_score)

    # shorter batches are filled with pad_token
    if sent_lengths.min().item() != sent_lengths.max().item():
        assert pad_token_id is not None, "`Pad_token_id` has to be defined"
        sent_max_len = min(sent_lengths.max().item() + 1, max_length)
        decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)

        # fill with hypothesis and eos_token_id if necessary
        for i, hypo in enumerate(best):
            decoded[i, : sent_lengths[i]] = hypo
            if sent_lengths[i] < max_length:
                decoded[i, sent_lengths[i]] = eos_token_id
    else:
        # none of the hypotheses have an eos_token
        assert all(len(hypo) == max_length for hypo in best)
        decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)

    if return_scores:
        return decoded, best_scores
    return decoded
```

For `_generate_no_beam_search`:
```python
    output_score = 0  # accumulates per-token log-probabilities; broadcasts to a (batch_size,) tensor

    while cur_len < max_length:
        model_inputs = self.prepare_inputs_for_generation(
            input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
        )

        outputs = self(**model_inputs)
        next_token_logits = outputs[0][:, -1, :]

        # if model has past, then set the past variable to speed up decoding
        if self._use_cache(outputs, use_cache):
            past = outputs[1]

        # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
        if repetition_penalty != 1.0:
            self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)

        if no_repeat_ngram_size > 0:
            # calculate a list of banned tokens to prevent repetitively generating the same ngrams
            # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
            banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
            for batch_idx in range(batch_size):
                next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")

        if bad_words_ids is not None:
            # calculate a list of banned tokens according to bad words
            banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)

            for batch_idx in range(batch_size):
                next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")

        # set eos token prob to zero if min_length is not reached
        if eos_token_id is not None and cur_len < min_length:
            next_token_logits[:, eos_token_id] = -float("inf")

        if do_sample:
            # Temperature (higher temperature => more likely to sample low probability tokens)
            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature
            # Top-p/top-k filtering
            next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            # Sample
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            # Greedy decoding
            next_token = torch.argmax(next_token_logits, dim=-1)
        # log-probability of the chosen token under the (possibly filtered) next-token distribution
        next_token_log_probs = F.log_softmax(next_token_logits, dim=-1)
        next_score = next_token_log_probs.gather(-1, next_token.unsqueeze(-1)).squeeze(-1)  # (batch_size,)
        # update generations and finished sentences
        if eos_token_id is not None:
            # pad finished sentences if eos_token_id exist
            tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
        else:
            tokens_to_add = next_token

        # add token and increase length by one
        input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
        output_score += next_score * unfinished_sents  # only accumulate for sentences that are still unfinished
        cur_len = cur_len + 1

        if eos_token_id is not None:
            eos_in_sents = tokens_to_add == eos_token_id
            # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
            is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
            sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
            # unfinished_sents is set to zero if eos in sentence
            unfinished_sents.mul_((~eos_in_sents).long())

        # stop when there is a </s> in each sentence, or if we exceed the maximum length
        if unfinished_sents.max() == 0:
            break

        # extend attention_mask for new generated input if only decoder
        if self.config.is_encoder_decoder is False:
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

    # if there are different sentences lengths in the batch, some batches have to be padded
    if sent_lengths.min().item() != sent_lengths.max().item():
        assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths"
        # finished sents are filled with pad_token
        decoded = input_ids.new(batch_size, sent_lengths.max().item()).fill_(pad_token_id)
    else:
        decoded = input_ids

    for hypo_idx, hypo in enumerate(input_ids):
        decoded[hypo_idx, : sent_lengths[hypo_idx]] = hypo[: sent_lengths[hypo_idx]]    

    if return_scores:
        return decoded, output_score
    return decoded
```
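
For illustration, here is how the proposed flag might be used from the caller's side. The `return_scores` argument and the `(decoded, scores)` return value are part of this proposal, not the current `generate` API, and the model name is only an example:

```python
# Hypothetical usage of the proposed `return_scores` flag (not part of the current API).
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

inputs = tokenizer("A long article to summarize ...", return_tensors="pt")
decoded, scores = model.generate(
    inputs["input_ids"],
    num_beams=4,
    num_return_sequences=4,
    max_length=60,
    return_scores=True,  # proposed flag: also return one cumulative log-probability per hypothesis
)
for seq, score in zip(decoded, scores):
    print(f"{score:.3f}\t{tokenizer.decode(seq, skip_special_tokens=True)}")
```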

As a next step, we could save the score per token, allowing the user to decide where to truncate the generated text as a function of confidence.
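
A minimal sketch of that idea, assuming `generate` also returned a per-token log-probability tensor aligned with the generated ids (it does not yet; `token_scores` and the threshold value below are illustrative):

```python
import torch

def truncate_by_confidence(generated_ids, token_scores, threshold=-2.0):
    """Cut a generated sequence at the first token whose log-probability
    falls below `threshold` and return the (possibly shorter) prefix.

    generated_ids: 1-D LongTensor of generated token ids.
    token_scores:  1-D FloatTensor of per-token log-probabilities, same length.
    """
    low_confidence = (token_scores < threshold).nonzero(as_tuple=False)
    if low_confidence.numel() == 0:
        return generated_ids  # every token is above the confidence threshold
    cut = low_confidence[0].item()
    return generated_ids[:cut]
```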


All 12 comments

I guess similar to output_attentions and output_hidden_states, we could output the scores / probabilities for generation, but I'm really not sure if it is required that often. What do you think @sshleifer @yjernite ?

I would suggest trying it on a branch and seeing if it produces better generations. I have been inspecting the scores this week (just by saving hypotheses to disk) and have not gotten much utility. If it helps produce better generations, however, we should obviously add this!

Dear @patrickvonplaten and @sshleifer, thanks for the quick reply.
I'm interested in the perplexity of my generated text as a function of different generation methods. This can be done using the probabilities of the output tokens.
Another interesting case that comes to mind is autocomplete, where you want to present the user with generated text only if it passes some confidence threshold.
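
For reference, a minimal sketch of that perplexity computation, assuming the per-token log-probabilities of the generated tokens were available (`token_log_probs` below is an assumed input, not something `generate` returns today):

```python
import torch

def sequence_perplexity(token_log_probs):
    """Perplexity of one generated sequence.

    token_log_probs: 1-D FloatTensor with log p(token_t | tokens_<t) for each generated token.
    """
    return torch.exp(-token_log_probs.mean()).item()
```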

Those are actually very useful applications! We will soon have a bigger refactoring of the generate method I think and will hopefully include this.

As @sshleifer said, for now, it would be great if you can show how you would integrate it on a branch including some interesting results.

Fantastic. Will do.

Thanks for raising the issue @guyeyal. It would definitely be helpful to have a running example.

More generally @patrickvonplaten, I think this functionality will be helpful for the line of research concerned with analyzing the role of perplexity as a training objective, as well as work on re-ranking generations or using stuff like noisy channel modeling, so I definitely think it should be in the next big refactor.

https://arxiv.org/abs/1904.09751
https://arxiv.org/abs/1908.05731

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

I'll add a vote here that I'm interested in this too. I wrote some code locally very similar to guyeyal's.

Thanks for the great work!

I would also be interested in this functionality. I am using an autoregressive transformer model as part of a reinforcement learning problem. To alleviate the sample inefficiency of RL, it is very attractive to generate data with beam search, in order to add `num_beams` > 1 sequences to a buffer per time step. I would then like to bias the sampling of data from this buffer according to the probability of the generated sequence, defined as in the diagram in this example:

https://huggingface.co/blog/how-to-generate#beam-search
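
A minimal sketch of that weighting, assuming the per-token log-probabilities of each returned beam were available (the names below are illustrative only, not part of the library):

```python
import torch

def buffer_sampling_weights(per_beam_token_log_probs):
    """Turn per-beam token log-probabilities into sampling weights for a replay buffer.

    per_beam_token_log_probs: list of 1-D FloatTensors, one per returned beam.
    Returns a 1-D tensor of weights proportional to each sequence's probability,
    i.e. the product of its conditional token probabilities, as in the linked diagram.
    """
    seq_log_probs = torch.stack([lp.sum() for lp in per_beam_token_log_probs])
    return torch.softmax(seq_log_probs, dim=0)
```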

@patrickvonplaten is this something that is likely to be covered in the PR here: https://github.com/huggingface/transformers/pull/6949
or is it better to open a new issue? Thanks!

There seems to be a lot of interest in this functionality! If someone feels like opening a PR that would be great!

> There seems to be a lot of interest in this functionality! If someone feels like opening a PR that would be great!

I saw a PR here, but it was never merged: #6289

> There seems to be a lot of interest in this functionality! If someone feels like opening a PR that would be great!
>
> I saw a PR here, but it was never merged: #6289

Any idea why this wasn't merged?

