Flair: Why are SentenceTransformerDocumentEmbeddings not fine-tunable?

Created on 21 Jul 2020 · 7 Comments · Source: flairNLP/flair

Hello! First of all, thanks for this awesome package!

I'm training a text classifier as described in the tutorial. Since my texts are quite short, I'm using a sentence transformer for the embeddings:

from torch.optim.adam import Adam

from flair.embeddings import SentenceTransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

# `corpus` is a flair Corpus loaded beforehand (not shown)
label_dict = corpus.make_label_dictionary()
embedding = SentenceTransformerDocumentEmbeddings('bert-base-nli-stsb-mean-tokens')
classifier = TextClassifier(embedding, label_dictionary=label_dict)
trainer = ModelTrainer(classifier, corpus, optimizer=Adam)
trainer.train('trained_models',
              learning_rate=3e-5, # use very small learning rate
              mini_batch_size=16,
              mini_batch_chunk_size=4, # optionally set this if transformer is too much for your machine
              max_epochs=1000,
              )

The model performs well on our data. However, I noticed that only the classifier layer is trained; the transformer stays untouched during training:

import torch
orig_layers = SentenceTransformerDocumentEmbeddings('bert-base-nli-stsb-mean-tokens').model[0].bert.encoder.layer
for idx, layer in enumerate(classifier.document_embeddings.model[0].bert.encoder.layer):
    orig_layer = orig_layers[idx]
    assert torch.allclose(layer.attention.self.query.weight, orig_layer.attention.self.query.weight)
    assert torch.allclose(layer.attention.self.key.weight, orig_layer.attention.self.key.weight)
    assert torch.allclose(layer.attention.self.value.weight, orig_layer.attention.self.value.weight)
    assert torch.allclose(layer.attention.output.dense.weight, orig_layer.attention.output.dense.weight)
    assert torch.allclose(layer.intermediate.dense.weight, orig_layer.intermediate.dense.weight)
    assert torch.allclose(layer.output.dense.weight, orig_layer.output.dense.weight)
# (All assertions pass)
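A more generic version of this check, sketched below under the assumption that you can snapshot the model before training, compares every parameter against a deep copy of its `state_dict()` instead of reloading the original model from the hub. The toy `nn.Sequential` and the helper `unchanged_parameters` are hypothetical stand-ins; with flair one could presumably snapshot `classifier.state_dict()` before calling `trainer.train(...)`, since the classifier is a PyTorch module.

```python
import copy

import torch
import torch.nn.functional as F
from torch import nn


def unchanged_parameters(before: dict, model: nn.Module) -> list:
    """Return the names of parameters whose values still equal the snapshot."""
    after = model.state_dict()
    return [name for name, tensor in before.items() if torch.equal(tensor, after[name])]


# Hypothetical toy model: a frozen "encoder" layer followed by a trainable head.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 2))
for p in model[0].parameters():
    p.requires_grad = False  # simulate an embedding that is never updated

# Deep copy, because state_dict() tensors share storage with the live parameters.
before = copy.deepcopy(model.state_dict())

optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-2)
x, y = torch.randn(16, 8), torch.randint(0, 2, (16,))
F.cross_entropy(model(x), y).backward()
optimizer.step()

print(unchanged_parameters(before, model))  # ['0.weight', '0.bias']
```

An empty list after training would mean every parameter, including the embedding, was updated.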

Looking at the implementation, I see that SentenceTransformerDocumentEmbeddings doesn't even have a fine_tune parameter:
https://github.com/flairNLP/flair/blob/17fa344e27c8af6454a81c9e74c9a57d2d5a94e4/flair/embeddings/document.py#L520-L530
and the static_embeddings attribute is explicitly set to True:
https://github.com/flairNLP/flair/blob/17fa344e27c8af6454a81c9e74c9a57d2d5a94e4/flair/embeddings/document.py#L550
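Why a static embedding never changes can be illustrated with plain PyTorch. The sketch below assumes that a static embedding is computed without autograd tracking (under `torch.no_grad()`); this is an assumption about the mechanism, not flair's actual code. The `encoder` and `head` layers and the function `encoder_grad_after_step` are hypothetical toy stand-ins for the transformer and the classifier layer.

```python
import torch
import torch.nn.functional as F
from torch import nn


def encoder_grad_after_step(static_embeddings: bool):
    """Run one forward/backward pass and return the encoder's weight gradient."""
    encoder = nn.Linear(8, 4)  # hypothetical stand-in for the transformer
    head = nn.Linear(4, 2)     # hypothetical stand-in for the classifier layer
    x, y = torch.randn(16, 8), torch.randint(0, 2, (16,))
    if static_embeddings:
        # No autograd graph is recorded, so nothing connects the loss to the encoder.
        with torch.no_grad():
            emb = encoder(x)
    else:
        emb = encoder(x)
    F.cross_entropy(head(emb), y).backward()
    return encoder.weight.grad


print(encoder_grad_after_step(static_embeddings=True))               # None
print(encoder_grad_after_step(static_embeddings=False) is not None)  # True
```

In the static case the loss still trains the head, which matches the observation that only the classifier layer changes.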

So why are SentenceTransformerDocumentEmbeddings not fine-tunable? Is it because they have already been fine-tuned for sentence embeddings on datasets like STS (as described in their paper)? Is it known to be a bad idea to fine-tune transformers that have already been fine-tuned?

Or do you have some other specific reason for keeping this kind of embedding static?

As far as I can see, #1492 announced that all transformers are now tunable in this library, yet SentenceTransformerDocumentEmbeddings are not.

Thanks in advance!!!

Labels: question, wontfix

All 7 comments

Hello @ilya-palachev, I am not sure whether sentence transformers can be further fine-tuned, or whether this would make sense. @nreimers, can you comment?

Hi @alanakbik @ilya-palachev

Yes, sentence transformers can be further fine-tuned. A SentenceTransformer is basically a PyTorch Sequential model (https://pytorch.org/docs/master/generated/torch.nn.Sequential.html) that first calls a BERT (etc.) model and then performs a mean pooling operation. If you use the forward function of SentenceTransformer, you get gradients for the BERT weights, and BERT is updated during training.
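To illustrate that structure, the following sketch builds a toy `nn.Sequential` of an encoder followed by a mean pooling module, passing a dict of features between them. `ToyEncoder` and `MeanPooling` are hypothetical stand-ins that only loosely mirror the sentence-transformers modules; the point is that calling the model's forward function keeps the autograd graph, so the encoder receives gradients.

```python
import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for BERT: maps token ids to (context-free) token vectors."""

    def __init__(self, vocab_size: int = 100, dim: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)

    def forward(self, features: dict) -> dict:
        features["token_embeddings"] = self.embed(features["input_ids"])
        return features


class MeanPooling(nn.Module):
    """Average the token vectors, ignoring padding positions (attention_mask == 0)."""

    def forward(self, features: dict) -> dict:
        mask = features["attention_mask"].unsqueeze(-1).float()  # (batch, seq_len, 1)
        summed = (features["token_embeddings"] * mask).sum(dim=1)
        features["sentence_embedding"] = summed / mask.sum(dim=1).clamp(min=1e-9)
        return features


model = nn.Sequential(ToyEncoder(), MeanPooling())
batch = {
    "input_ids": torch.tensor([[5, 7, 9, 0], [3, 4, 0, 0]]),  # 0 = padding id
    "attention_mask": torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]),
}
emb = model(batch)["sentence_embedding"]  # shape (2, 8)
emb.pow(2).sum().backward()               # any loss built on the forward output
print(model[0].embed.weight.grad is not None)  # True: the encoder gets gradients
```

Because padding positions are multiplied by zero, the embedding row of the padding id (0 here) receives a zero gradient.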

Would it make sense to fine-tune them? If you have enough training data, I think it would make sense.

By the way, most models are available from us in the huggingface repository:
https://huggingface.co/sentence-transformers

I see that flair has a TransformerDocumentEmbeddings class, so you could try this:

from flair.embeddings import TransformerDocumentEmbeddings

document_embeddings = TransformerDocumentEmbeddings('sentence-transformers/bert-base-nli-mean-tokens', fine_tune=True)

This would load our sentence-transformers bert-base-nli-mean-tokens model, without any pooling layer.

I am not sure which pooling strategy TransformerDocumentEmbeddings uses. Does it use mean pooling, or does it use the CLS token as the embedding?

If it uses the CLS token as the embedding, then this would be the right model:
https://huggingface.co/sentence-transformers/bert-base-nli-cls-token
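The choice of model matters because a model trained with mean pooling was optimized for the averaged vector, not for the first-token vector. A minimal pure-Python sketch with made-up 2-dimensional token vectors shows that the two strategies yield different sentence embeddings for the same tokens; the function names are illustrative only.

```python
def cls_pooling(token_vectors, mask):
    """Use the first token's vector ([CLS] for BERT-style models); mask is unused."""
    return list(token_vectors[0])


def mean_pooling(token_vectors, mask):
    """Average the vectors of real (non-padding) tokens, [CLS] included."""
    kept = [vec for vec, m in zip(token_vectors, mask) if m]
    return [sum(dim) / len(kept) for dim in zip(*kept)]


# Hypothetical outputs for "[CLS] short text [PAD]".
tokens = [[1.0, 0.0], [3.0, 2.0], [5.0, 4.0], [9.0, 9.0]]
mask = [1, 1, 1, 0]

print(cls_pooling(tokens, mask))   # [1.0, 0.0]
print(mean_pooling(tokens, mask))  # [3.0, 2.0]
```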

Best
Nils Reimers

Ah interesting, thanks! Yes, the TransformerDocumentEmbeddings class uses the CLS token or equivalent.

So for the CLS sentence transformers, I guess we actually don't need a separate class. For the others, we would need to add a mean pooling layer; then we could have all transformers in one class, right?

Hi @alanakbik

Yes, adding mean pooling to the TransformerDocumentEmbeddings class would be quite nice.

See here for how to do this with minimal code using the HF AutoModel:
https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens
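The idea boils down to a small masked-average function over the AutoModel's last hidden state. The sketch below follows that pattern; the tensor values are invented so it runs without downloading a model.

```python
import torch


def mean_pooling(model_output, attention_mask):
    # model_output[0] holds the last hidden state: (batch, seq_len, hidden)
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask, dim=1)
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)  # avoid division by zero
    return summed / counts


# Stand-in for a model output: one sentence, 3 tokens, the last one is padding.
hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
attention_mask = torch.tensor([[1, 1, 0]])
print(mean_pooling((hidden,), attention_mask))  # tensor([[2., 3.]])
```

With a real model, `model_output` would be the result of calling the model loaded via `AutoModel.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')` on the tokenizer output, and `attention_mask` would be that output's `attention_mask` entry.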

Sometimes max pooling is also quite nice. Here you can find code showing how to do max pooling with the HF AutoModel:
https://huggingface.co/sentence-transformers/bert-base-nli-max-tokens
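Max pooling differs mainly in how padded positions are neutralized: they are set to a very large negative value before taking the maximum over the sequence. A sketch under the same made-up inputs; it uses `masked_fill`, which returns a new tensor, so the model output itself is not modified.

```python
import torch


def max_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # (batch, seq_len, hidden)
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size())
    # Push padding positions to a very low value so they never win the max.
    masked = token_embeddings.masked_fill(mask == 0, -1e9)
    return torch.max(masked, dim=1).values


hidden = torch.tensor([[[1.0, 5.0], [3.0, 4.0], [100.0, 100.0]]])
attention_mask = torch.tensor([[1, 1, 0]])
print(max_pooling((hidden,), attention_mask))  # tensor([[3., 5.]])
```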

The pooling mechanism could be added as a parameter to the TransformerDocumentEmbeddings class.
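One way such a parameter could look is sketched below: a small module that selects CLS, mean, or max pooling by name. `PooledDocumentEmbedding` and its `pooling` argument are hypothetical and not part of flair's API; `token_embeddings` is assumed to be the transformer's last hidden state with shape (batch, seq_len, hidden).

```python
import torch
from torch import nn


class PooledDocumentEmbedding(nn.Module):
    """Hypothetical sketch: reduce token embeddings to one vector per document."""

    def __init__(self, pooling: str = "cls"):
        super().__init__()
        if pooling not in ("cls", "mean", "max"):
            raise ValueError(f"unknown pooling strategy: {pooling}")
        self.pooling = pooling

    def forward(self, token_embeddings, attention_mask):
        if self.pooling == "cls":
            return token_embeddings[:, 0]  # first token, e.g. [CLS]
        mask = attention_mask.unsqueeze(-1).bool()  # (batch, seq_len, 1)
        if self.pooling == "mean":
            summed = (token_embeddings * mask).sum(dim=1)
            return summed / mask.sum(dim=1).clamp(min=1)
        # max pooling: padding positions can never be the maximum
        return token_embeddings.masked_fill(~mask, -1e9).max(dim=1).values


hidden = torch.tensor([[[1.0, 5.0], [3.0, 4.0], [100.0, 100.0]]])
attention_mask = torch.tensor([[1, 1, 0]])
for strategy in ("cls", "mean", "max"):
    print(strategy, PooledDocumentEmbedding(strategy)(hidden, attention_mask))
```

Validating the strategy name in the constructor makes a typo fail when the embedding is created rather than in the middle of training.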

Best
Nils

@nreimers thanks for the info - we'll get right on it :)

Hi! Was mean pooling ever added in the recent releases? I was curious to try it 😃

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
