Transformers: [generation] multiple eos/pad asserts/ifs in generate search functions

Created on 7 Sep 2020 · 7 comments · Source: huggingface/transformers

In _generate_no_beam_search, eos_token_id is required: https://github.com/huggingface/transformers/blob/master/src/transformers/generation_utils.py#L731 (that code always gets hit):

                    assert (
                        eos_token_id is not None and pad_token_id is not None
                    ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"

Why do we assert and check eos_token_id is not None multiple times throughout the code? Why not assert once at the top of the function and then just use it?

Moreover, all those if eos_token_id is not None checks could then be removed (or reduced, where they contain other conditions).

Also, a larger question: is there a model where eos_token_id is not defined? If there is none, why not assert once at the top of generate and then just use it everywhere in the sub-calls without re-testing whether it's defined?

Oh, I also see pad_token_id is used in _generate_no_beam_search without testing whether it's defined: https://github.com/huggingface/transformers/blob/master/src/transformers/generation_utils.py#L571

                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)

Is it the same situation as eos_token_id, i.e. is it always needed?

I see it may be set here: https://github.com/huggingface/transformers/blob/master/src/transformers/generation_utils.py#L355, but only if eos_token_id is defined:

        if pad_token_id is None and eos_token_id is not None:
            logger.warning(
                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
            )
            pad_token_id = eos_token_id
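
For example (a hypothetical illustration, not from the issue): gpt2 defines eos_token_id but no pad_token_id, so a plain generate() call should trigger exactly this fallback and log the warning quoted above:

    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")  # eos_token_id is set, pad_token_id is not

    input_ids = tokenizer("Hello, my dog", return_tensors="pt").input_ids

    # with no pad_token_id passed in and none in the config, generate() falls back to
    # pad_token_id = eos_token_id and logs the "Setting `pad_token_id` to ..." warning
    output = model.generate(input_ids, max_length=15)
    print(tokenizer.decode(output[0]))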

My thinking is that if this has worked for all models until now, it's further evidence that eos_token_id is in effect required.

So in _generate_no_beam_search pad_token_id is required as well, and, like eos_token_id, it could be asserted once at the top rather than checked multiple times throughout the code.
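
A minimal sketch of what that single up-front check could look like (hypothetical, not a patch, and the signature is abbreviated), assuming both tokens really were always required:

    def _generate_no_beam_search(self, input_ids, max_length, eos_token_id, pad_token_id, **kwargs):
        # check once at the top instead of sprinkling asserts/ifs through the loop
        assert eos_token_id is not None, "`eos_token_id` has to be defined for generation"
        assert pad_token_id is not None, "`pad_token_id` has to be defined for generation"
        # ... the rest of the function could then use both tokens directly,
        # dropping the repeated `if eos_token_id is not None:` branches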

Thank you for reviewing my observations. It's possible that some (all?) are incorrect if I missed something.

All 7 comments

I think eos is always defined, but I think this (or just the pad_token_id checking) was one of Patrick's first PRs. He would know more.

Thank you for that feedback, @sshleifer.

If it makes things simpler, I could re-work both functions wrt these 2 tokens' definition checks and you can review a PR instead.

I just wanted to validate that the issue is real and I'm not missing something obvious before I invest time into doing that.

Hey @stas00,

This is definitely part of the code that should be refactored :D Super hard to follow the logic there :-/

As a start, this PR is probably quite useful for context: https://github.com/huggingface/transformers/pull/2885. So there are a couple of models where the EOS token is not defined, and I'm quite sure that the code you linked does not always get hit. It can very well be that we apply beam search to OpenAIGPT with a given max_length. OpenAIGPT does not have an EOS token, but beam search should work nevertheless.
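
Something along these lines (a rough sketch, not a test from the repo) exercises that case - openai-gpt ships without an eos_token_id, so the only stopping criterion is max_length:

    from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer

    tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
    model = OpenAIGPTLMHeadModel.from_pretrained("openai-gpt")
    print(model.config.eos_token_id)  # None - no EOS token for this model

    input_ids = tokenizer("the weather today is", return_tensors="pt").input_ids

    # beam search with no EOS token: generation only stops at max_length,
    # so an assert that unconditionally requires eos_token_id would break this
    output = model.generate(input_ids, num_beams=4, max_length=20)
    print(tokenizer.decode(output[0]))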

It's quite a tricky pad token / eos token / ... logic that is implemented there. I think we have to be super careful to not break anything here - even if all the slow tests pass, it might not be enough (OpenAIGPT beam search is not integration tested...)

Also, I'm currently working on refactoring the generate function and will ping you guys in a couple of days with a first design proposition. My idea is to pull apart beam search + greedy / beam search + sampling / no beam search + greedy / no beam search + sampling to make everything more readable. I'm not sure whether it's worth diving deep into the generate() logic before we have more readable code.

That sounds like a fantastic plan, @patrickvonplaten!

So there are a couple of models where EOS token is not defined and I'm quite sure that the code you linked does not always get hit.

I stand corrected, that's good to know, thank you!

That means that the code is very tricky, since a reader will expect that at some point the generation should be complete and done set to True, which currently absolutely requires eos. I hadn't considered the case where it goes through that loop and never hits done. If I follow it carefully, that only happens when max_length is reached while done isn't, and moreover the hypos have to be of exactly the same length. If they aren't the same length, eos is almost always required.

As you are saying, there isn't really a test that covers that (odd?) case. Actually, PR https://github.com/huggingface/transformers/pull/6982 is then very likely to break it, since it now requires eos both when the hypos are of the same length and when they are not. But if it breaks that very special case, then the issue lies elsewhere and it just happened to work. (As I mentioned, I changed "is" to "was" in an input and suddenly eos was gone from all of the hypos.)

Note: I have only run the code in my head and haven't validated that in fact it'd break something. It's possible that you're talking about a completely different case.

I think your PR is fine, because if no eos_token_id is defined this condition can never happen: sent_lengths[i] < max_length.
What I mean is that if no eos_token_id is defined, then no matter which generate() method is used, all sent_lengths will always be == max_length and the condition will not be hit.
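
Roughly, the finalization step being discussed looks like this (a paraphrase of the logic, not the exact library code; variable names are approximate):

    # sent_lengths[i] is the length of the i-th finished hypothesis
    sent_max_len = min(sent_lengths.max().item() + 1, max_length)
    decoded = input_ids.new(batch_size, sent_max_len)

    # padding is only needed when hypotheses have different lengths, which can only
    # happen when some of them finished early on an EOS token
    if sent_lengths.min().item() != sent_lengths.max().item():
        assert pad_token_id is not None, "`pad_token_id` has to be defined"
        decoded.fill_(pad_token_id)

    for i, hypo in enumerate(best):
        decoded[i, : sent_lengths[i]] = hypo
        if sent_lengths[i] < max_length:
            # never reached when eos_token_id is None: every hypothesis then runs to max_length
            decoded[i, sent_lengths[i]] = eos_token_id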

Ah, yes, you're absolutely correct, Patrick - you've definitely been holding that generation code in your head for much longer than I have - I don't have the full coverage yet :)

Reopen if this was a mistake!
