Hi,
I am fine-tuning a BERT model (based on `BertForTokenClassification`) on a NER task with 9 labels ("O" plus BILU tags for 2 classes), and sometimes during training I run into this odd behavior: a network with 99% accuracy that is showing a converging trend suddenly shifts all of its predictions to a single class. This happens within a single epoch.
Below are the confusion matrices and some other metrics one epoch before the event and after the event:
Epoch 7/10: 150.57s/it, val_acc=99.718% (53391/53542), val_acc_bilu=87.568% (162/185), val_rec=98.780%, val_prec=55.862%, val_f1=71.366%
Confusion matrix:
[[53229 2 66 25 2 25 8]
[ 0 7 0 7 0 0 0]
[ 0 0 14 0 0 0 0]
[ 0 0 0 67 0 0 1]
[ 1 0 0 3 11 0 1]
[ 1 0 1 0 0 14 0]
[ 0 0 0 7 1 0 49]]
Epoch 8/10: 150.64s/it, val_acc=0.030% (16/53542), val_acc_bilu=8.649% (16/185), val_rec=100.000%, val_prec=0.030%, val_f1=0.060%
Confusion matrix:
[[ 0 0 0 0 53357 0 0]
[ 0 0 0 0 14 0 0]
[ 0 0 0 0 14 0 0]
[ 0 0 0 0 68 0 0]
[ 0 0 0 0 16 0 0]
[ 0 0 0 0 16 0 0]
[ 0 0 0 0 57 0 0]]
I am using the default config for `bert-base-multilingual-cased` and the standard `CrossEntropyLoss`. The optimizer is `BertAdam`, untouched, with a learning rate of 1e-5. The dataset is highly unbalanced (very few named entities, so >99% of the tokens are "O" tags), so I apply a weight of 0.01 to the "O" tag in the cross-entropy loss.
Has anyone faced a similar issue?
Thanks in advance
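
For reference, the weighting described above can be expressed by passing a `weight` tensor to `CrossEntropyLoss`. This is a minimal sketch; the label names and their order (with "O" first) are assumptions for illustration:

```python
import torch
import torch.nn as nn

# Hypothetical label set: "O" first, then BILU tags for two entity classes.
labels = ["O",
          "B-PER", "I-PER", "L-PER", "U-PER",
          "B-ORG", "I-ORG", "L-ORG", "U-ORG"]

# Down-weight the dominant "O" class (0.01), keep 1.0 for the entity tags.
class_weights = torch.tensor([0.01] + [1.0] * (len(labels) - 1))

loss_fct = nn.CrossEntropyLoss(weight=class_weights)
```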
I managed to solve this problem. There is an issue in the calculation of the total optimization steps in the `run_squad.py` example that results in a negative learning rate because of the `warmup_linear` schedule. This happens because `t_total` is calculated based on `len(train_examples)` instead of `len(train_features)`. That may not be a problem for datasets with short sentences, but for long sentences one example may generate many entries in `train_features` due to the strategy of dividing an example into `DocSpan`s.
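
To illustrate why an underestimated `t_total` can push the learning rate below zero, here is a simplified reconstruction of a linear warmup-then-decay schedule of the kind used in pytorch-pretrained-bert (the exact implementation and default warmup value may differ; the numbers below are hypothetical):

```python
def warmup_linear(progress, warmup=0.1):
    """Linear warmup followed by linear decay.

    `progress` is global_step / t_total. If t_total is underestimated
    (e.g. computed from len(train_examples) instead of len(train_features)),
    progress eventually exceeds 1.0 and the multiplier turns negative.
    """
    if progress < warmup:
        return progress / warmup
    return 1.0 - progress


base_lr = 1e-5

# t_total computed from examples, but long examples are split into several
# features (DocSpans), so the real number of optimization steps is larger.
t_total_wrong = 1000   # from len(train_examples)
actual_steps = 1500    # from len(train_features)

lr_at_end = base_lr * warmup_linear(actual_steps / t_total_wrong)
print(lr_at_end)  # 1e-5 * (1.0 - 1.5) = -5e-06, i.e. a negative learning rate
```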
@fabiocapsouza I am trying to handle text classification but my dataset is also highly unbalanced. I am trying to find where I can adjust the class weights when training transformers. Which parameter you changed in your case?
@MendesSP, since the provided BERT model classes have the loss function hardcoded in the `forward` method, I had to write a subclass to override the `CrossEntropyLoss` definition, passing a `weight` tensor.
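
A sketch of that kind of subclass, assuming the pytorch-pretrained-bert era `BertForTokenClassification` API (class names, constructor arguments, and the exact `forward` signature may differ across versions; `class_weights` is the tensor built as in the earlier snippet):

```python
import torch
from torch.nn import CrossEntropyLoss
from pytorch_pretrained_bert import BertForTokenClassification  # assumed library/version


class WeightedBertForTokenClassification(BertForTokenClassification):
    """BertForTokenClassification with a class-weighted CrossEntropyLoss.

    The parent's forward() hardcodes an unweighted CrossEntropyLoss, so the
    loss is recomputed here from the logits using a `weight` tensor.
    """

    def __init__(self, config, num_labels, class_weights):
        super().__init__(config, num_labels)
        # e.g. torch.tensor([0.01] + [1.0] * (num_labels - 1)) to down-weight "O"
        self.register_buffer("class_weights", class_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        # Call the parent without labels so it returns only the logits.
        logits = super().forward(input_ids, token_type_ids, attention_mask)
        if labels is None:
            return logits
        loss_fct = CrossEntropyLoss(weight=self.class_weights)
        if attention_mask is not None:
            # Compute the loss only on non-padding tokens.
            active = attention_mask.view(-1) == 1
            active_logits = logits.view(-1, self.num_labels)[active]
            active_labels = labels.view(-1)[active]
            return loss_fct(active_logits, active_labels)
        return loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
```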