Transformers: DistilBERT Loss Function Choice and further query on extending to GPT2.

Created on 2 Sep 2019 · 7 Comments · Source: huggingface/transformers

❓ Questions & Help

Can you describe the motivation behind scaling the KLDivLoss by the squared temperature?

https://github.com/huggingface/pytorch-transformers/blob/50792dbdcccd64f61483ec535ff23ee2e4f9e18d/examples/distillation/distiller.py#L331

When applying the same logic to GPT-2 distillation, I did the following:

    def training_step(self, data_batch, batch_i):
        """
        Lightning calls this inside the training loop
        :param data_batch:
        :return:
        """
        # forward pass

        token_ids, lengths = data_batch
        orig_loss_ce, s_logits = self.student(input_ids=token_ids, labels=token_ids)[:2]     # (bs, seq_length, voc_size)

        self.teacher.eval() # Required to do this every time.
        with torch.no_grad():
            t_logits = self.teacher(input_ids=token_ids)[0] # (bs, seq_length, voc_size)

        loss_kl = self.kld_loss_fct(F.log_softmax(s_logits/self.temperature, dim=-1),
                                    F.softmax(t_logits/self.temperature, dim=-1)) * (self.temperature)**2
        loss_kl /= s_logits.shape[1]
        loss = self.alpha_kl * loss_kl
        if self.alpha_orig_ce > 0.:
            loss += self.alpha_orig_ce * orig_loss_ce
        if self.alpha_mse > 0.:
            loss_mse = self.mse_loss_fct(s_logits, t_logits)/s_logits.size(0) # Reproducing batchmean reduction
            loss += self.alpha_mse * loss_mse

        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
        if self.trainer.use_dp:
            loss = loss.unsqueeze(0)

        output = OrderedDict({
            'loss': loss
        })

        # can also return just a scalar instead of a dict (return loss_val)
        return output


I found that the DistilBERT implementation led to a high initial range for the KL loss (130-180), depending on the average sequence length per batch, while the cross entropy was in the range of 4-5.
So I divided loss_kl by the number of timesteps in the batch. (My batches don't have any masked tokens.) Training did converge to a perplexity similar to the teacher's on the held-out set of Toronto Books.
Is my method well motivated, or am I applying the KL wrongly in the GPT-2 scenario, necessitating the scaling?

Most helpful comment

Hello @sai-prasanna,
I believe that in the original implementation we released, the knowledge distillation loss is batch-averaged, meaning that it should not be sensitive to the sequence lengths: self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean'). In any case, you should just make sure that, in the end, if your true loss is batch-size-agnostic, then the knowledge distillation loss is too.

Regarding your first question, the T**2 rescaling simply ensures that both the true loss and the distillation loss are of the same magnitude. You can refer to the original paper, section 2: _"Since the magnitudes of the gradients produced by the soft targets scale as 1/T^2 it is important to multiply them by T^2 when using both hard and soft targets."_
Victor
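To make the 1/T^2 argument concrete, here is a minimal sketch (not the code from the distillation example) that measures the gradient of the temperature-softened KL loss with respect to the student logits, with and without the T**2 factor. The helper names `kd_loss` and `grad_norm` and the random logits are assumptions made for illustration only.

```python
import torch
import torch.nn.functional as F


def kd_loss(student_logits, teacher_logits, temperature, rescale=True):
    """KL(teacher || student) on temperature-softened distributions."""
    log_q = F.log_softmax(student_logits / temperature, dim=-1)
    p = F.softmax(teacher_logits / temperature, dim=-1)
    loss = F.kl_div(log_q, p, reduction="batchmean")
    # Optionally multiply by T^2, as suggested in the quoted passage.
    return loss * temperature ** 2 if rescale else loss


def grad_norm(student_logits, teacher_logits, temperature, rescale):
    """Norm of d(loss)/d(student_logits) for a given temperature."""
    s = student_logits.clone().requires_grad_(True)
    kd_loss(s, teacher_logits, temperature, rescale).backward()
    return s.grad.norm().item()


torch.manual_seed(0)
student = torch.randn(8, 100)  # hypothetical (batch, classes) logits
teacher = torch.randn(8, 100)

for T in (1.0, 2.0, 4.0, 8.0):
    raw = grad_norm(student, teacher, T, rescale=False)
    scaled = grad_norm(student, teacher, T, rescale=True)
    print(f"T={T}: raw grad norm={raw:.5f}  with T^2={scaled:.5f}")
```

Without the factor, the gradient norm shrinks quickly as T grows; with it, the norm stays in a comparable range across temperatures, which is what allows the soft-target loss to be mixed with a hard-target loss using fixed weights.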

Thanks! I will recheck the loss function ranges more carefully. I guess I jumped ahead without reading the literature carefully; I will revisit the papers.

I have a few queries about pre-processing text for the GPT-2 student. (I PM'ed you on Twitter, but I guess this place is more accessible to others.)

Any guesses on how GPT-2 sequences were sampled for training?

Did they take any random point in the corpus and sample from there, or would they select a random token (could be in the middle of a sentence) and continue to fill the sequence from that point?

And what of sequence lengths, would they fill up tokens continuously (going across sentence boundaries) till max sequence length? Or would there be variation in sequence lengths and what would be an ideal way to sample the variations?

> Any guesses on how GPT-2 sequences were sampled for training?

You should refer to the papers GPT and GPT2 (section 2.1/2.2) for a detailed explanation of how the data are processed.

> Did they take any random point in the corpus and sampled from there, or would they select a random token (could be in the middle of a sentence) and continue to fill the sequence from that point?

In an auto-regressive LM (like GPT* for instance), each token (except the last one in the sequence) induces a training signal by having the model predict the next token.

> And what of sequence lengths, would they fill up tokens continuously (going across sentence boundaries) till max sequence length? Or would there be variation in sequence lengths and what would be an ideal way to sample the variations?

More generally, the longer the sequences, the better (that's one of the things the RoBERTa paper showed). You want to train on dependencies that are as long as you can.
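One way to get sequences that are always as long as possible is to concatenate tokenized documents into a single stream and cut it into fixed-size blocks, crossing sentence and document boundaries. The sketch below only illustrates that idea; it is not a description of how GPT-2 was actually pre-processed, and the function name `pack_into_blocks`, the separator id, and the toy token ids are all hypothetical.

```python
from typing import Iterable, List


def pack_into_blocks(
    documents: Iterable[List[int]], block_size: int, eos_id: int
) -> List[List[int]]:
    """Concatenate token-id lists and split them into full blocks.

    A separator token (eos_id) is appended after every document so the
    model can still see where one document ends and the next begins.
    A trailing partial block is dropped.
    """
    stream: List[int] = []
    for doc in documents:
        stream.extend(doc)
        stream.append(eos_id)
    n_blocks = len(stream) // block_size
    return [
        stream[i * block_size:(i + 1) * block_size] for i in range(n_blocks)
    ]


# Toy example: three "documents" of token ids, blocks of 4 tokens.
docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
blocks = pack_into_blocks(docs, block_size=4, eos_id=0)
print(blocks)
```

With this scheme every block has exactly `block_size` tokens, so no padding is needed, at the cost of some blocks starting in the middle of a sentence.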

Thanks. I guess GPT-2 also follows GPT's preprocessing.

I guess my second point was rather unclear. I understand that GPT-2 does traditional language modeling. I want to know whether the inputs to the LM during training strictly start at sentence starts, like
"The quick brown cyborg jumped over the lazy sapien. And the cyborg, ..."
Or can inputs be like
"cyborg jumped over the lazy sapien. and the cyborg, ..."
"jumped over the lazy sapien. and the cyborg, ..."

Any hypothesis on how varying the training data like that would affect generation? Say one always gives a context that starts properly; would there then be any gain in not giving sentences that start from the middle?

@VictorSanh I experimented with KLDivLoss(reduction='batchmean'). I can confirm that the loss scales with the sequence length.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def test_kl_div_loss(batch, timesteps, hidden, n=10000):
        loss_fn = nn.KLDivLoss(reduction='batchmean')
        # The logits are sampled once, so each of the n iterations sees the same inputs.
        student_logits = torch.randn(batch, timesteps, hidden)
        teacher_logits = torch.randn(batch, timesteps, hidden)
        mean_loss = 0.0
        for _ in range(n):
            mean_loss += loss_fn(F.log_softmax(student_logits, dim=-1), F.softmax(teacher_logits, dim=-1))
        mean_loss /= n
        return mean_loss

    In [79]: test_kl_div_loss(batch=10, timesteps=10, hidden=10)
    Out[79]: tensor(8.4171)
    In [79]: test_kl_div_loss(batch=10, timesteps=100, hidden=10)
    Out[79]: tensor(77.5201)
    In [83]: test_kl_div_loss(batch=10, timesteps=1000, hidden=10)
    Out[83]: tensor(807.4752)

nn.KLDivLoss with reduction='batchmean' is proportional to the total number of timesteps, since it sums over every dimension and divides only by the batch size. And reduction='mean' is not right either, as it additionally averages over the number of classes.

In nn.CrossEntropyLoss we flatten the time dimension into the batch dimension and then compute the cross entropy; in effect this averages the loss across timesteps, as the default reduction is 'mean'.

So when computing the KL divergence, should we ideally set reduction='none' and scale the loss by ( 1 / total_actual_non_padding_tokens_in_batch )?
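The proposal in the question above could look roughly like the following. This is a sketch under assumptions, not code from the distillation example: the function name `token_mean_kl` and the `attention_mask` argument (1 for real tokens, 0 for padding) are hypothetical.

```python
import torch
import torch.nn.functional as F


def token_mean_kl(student_logits, teacher_logits, attention_mask, temperature=1.0):
    """Average KL(teacher || student) over non-padding tokens.

    student_logits, teacher_logits: (batch, seq_len, vocab)
    attention_mask: (batch, seq_len), 1 for real tokens and 0 for padding
    """
    log_q = F.log_softmax(student_logits / temperature, dim=-1)
    p = F.softmax(teacher_logits / temperature, dim=-1)
    # reduction='none' keeps one value per (batch, position, class) entry.
    per_class = F.kl_div(log_q, p, reduction="none")
    per_token = per_class.sum(dim=-1)  # (batch, seq_len)
    mask = attention_mask.to(per_token.dtype)
    # Divide by the number of real tokens, not by batch size or vocab size.
    loss = (per_token * mask).sum() / mask.sum().clamp(min=1.0)
    return loss * temperature ** 2


torch.manual_seed(0)
for timesteps in (10, 100, 1000):
    s = torch.randn(10, timesteps, 10)
    t = torch.randn(10, timesteps, 10)
    mask = torch.ones(10, timesteps)
    print(timesteps, token_mean_kl(s, t, mask).item())
```

Because the sum is divided by the number of unmasked tokens, the value should stay on the same scale as a token-averaged cross entropy regardless of sequence length, and padded positions contribute nothing.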

> Thanks. Guess gpt 2 also follows gpt's preprocessing.
>
> I guess my second point was rather unclear. I understand that gpt 2 does traditional lm. I want to know whether inputs to lm while training, strictly start at sentence starts.
> "The quick brown cyborg jumped over the lazy sapien. And the cyborg, ..."
> Or can inputs be like
> "cyborg jumped over the lazy sapien. and the cyborg, ..."
> "jumped over the lazy sapien. and the cyborg, ..."
>
> Any hypothesis on how varying training data like that would affect generation? Say one always gives context that start properly, then would there be any gain in not giving sentences that start from middle.

You could do the second option, but I am not sure it brings significantly more training signal than the first option, so we usually go with the first.
You should have a look at how it is done in GPT/GPT2. Folks at Nvidia have released their pre-processing script for GPT2: see here.

> So ideally, when computing the KL Div, should we ideally set the reduction='none' and scale the loss by ( 1 / total_actual_non_padding_tokens_in_batch ) ?

What I simply do in the training code is student_logits.view(-1, hidden), so that in the end the loss is agnostic to both sequence length and batch size (see here for instance).
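As a quick check of the flattening approach, the experiment from earlier in the thread can be repeated with the batch and time dimensions merged before the loss is computed. This is a minimal sketch; the function name `flattened_kl` and the fixed seed are assumptions, and `vocab` plays the role of the dimension called `hidden` above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def flattened_kl(batch, timesteps, vocab, seed=0):
    """KLDivLoss(batchmean) on logits reshaped to (batch * timesteps, vocab)."""
    torch.manual_seed(seed)
    loss_fn = nn.KLDivLoss(reduction="batchmean")
    student_logits = torch.randn(batch, timesteps, vocab)
    teacher_logits = torch.randn(batch, timesteps, vocab)
    # Merge batch and time so that 'batchmean' divides by the token count.
    s = student_logits.view(-1, vocab)
    t = teacher_logits.view(-1, vocab)
    return loss_fn(F.log_softmax(s, dim=-1), F.softmax(t, dim=-1)).item()


for timesteps in (10, 100, 1000):
    print(timesteps, flattened_kl(batch=10, timesteps=timesteps, vocab=10))
```

After flattening, 'batchmean' divides the summed KL by batch * timesteps, so the printed values should stay on the same scale instead of growing with the sequence length. Note that this treats every position equally; if the batch contains padding, those positions would still need to be masked out before the reduction.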

Thanks for taking your time to answer all my queries.
