Hello! First of all, thanks for this awesome package!
I'm training a text classifier as described in the tutorial. Since my texts are quite short, I'm using the sentence transformer for embedding:
```python
from torch.optim.adam import Adam

from flair.embeddings import SentenceTransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

label_dict = corpus.make_label_dictionary()
embedding = SentenceTransformerDocumentEmbeddings('bert-base-nli-stsb-mean-tokens')
classifier = TextClassifier(embedding, label_dictionary=label_dict)
trainer = ModelTrainer(classifier, corpus, optimizer=Adam)
trainer.train('trained_models',
              learning_rate=3e-5,  # use a very small learning rate
              mini_batch_size=16,
              mini_batch_chunk_size=4,  # optionally set this if the transformer is too much for your machine
              max_epochs=1000,
              )
```
The model performs well on our data. However, I realized that only the classifier layer is trained; the transformer stays untouched during training:
```python
import torch

orig_layers = SentenceTransformerDocumentEmbeddings('bert-base-nli-stsb-mean-tokens').model[0].bert.encoder.layer
for idx, layer in enumerate(classifier.document_embeddings.model[0].bert.encoder.layer):
    orig_layer = orig_layers[idx]
    assert torch.allclose(layer.attention.self.query.weight, orig_layer.attention.self.query.weight)
    assert torch.allclose(layer.attention.self.key.weight, orig_layer.attention.self.key.weight)
    assert torch.allclose(layer.attention.self.value.weight, orig_layer.attention.self.value.weight)
    assert torch.allclose(layer.attention.output.dense.weight, orig_layer.attention.output.dense.weight)
    assert torch.allclose(layer.intermediate.dense.weight, orig_layer.intermediate.dense.weight)
    assert torch.allclose(layer.output.dense.weight, orig_layer.output.dense.weight)
# (All assertions pass)
```
Looking at the implementation, I see that the sentence transformer embedding doesn't even have a fine_tune parameter:
https://github.com/flairNLP/flair/blob/17fa344e27c8af6454a81c9e74c9a57d2d5a94e4/flair/embeddings/document.py#L520-L530
and the static_embeddings parameter is explicitly set to True:
https://github.com/flairNLP/flair/blob/17fa344e27c8af6454a81c9e74c9a57d2d5a94e4/flair/embeddings/document.py#L550
So why are SentenceTransformerDocumentEmbeddings not fine-tunable? Is it because they are already fine-tuned for sentence embedding on data like the STS dataset (as described in their paper), and it is known to be unwise to fine-tune already fine-tuned transformers?
Or is there some other specific reason for keeping this kind of embedding static?
As far as I can see, #1492 announced that all transformers are now tunable in this library. Only SentenceTransformerDocumentEmbeddings is not.
Thanks in advance!!!
Hello @ilya-palachev, I am not sure whether sentence transformers can be further fine-tuned, or whether this makes sense. @nreimers can you comment?
Hi @alanakbik @ilya-palachev
Yes, sentence transformers can be further fine-tuned. A SentenceTransformer is basically a PyTorch Sequential model (https://pytorch.org/docs/master/generated/torch.nn.Sequential.html) that first calls a BERT (etc.) model and then performs a mean pooling operation. If the forward function of SentenceTransformers is used, you get gradients for the weights in BERT, and BERT is updated.
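To illustrate why gradients would reach the encoder, here is a toy sketch (a stand-in, not the actual SentenceTransformer or BERT code): a small linear "encoder" followed by a mean-pooling module, wrapped in nn.Sequential. After a backward pass, the encoder weights have non-zero gradients, so an optimizer step would fine-tune them end-to-end:

```python
import torch
import torch.nn as nn

class MeanPooling(nn.Module):
    """Toy pooling module: average token embeddings over the sequence dimension."""
    def forward(self, token_embeddings):
        return token_embeddings.mean(dim=1)

# Toy stand-in for a transformer encoder (NOT real BERT): a linear layer applied
# per token, then mean pooling -- mirroring the (encoder -> pooling) Sequential
# structure that sentence-transformers uses.
encoder = nn.Linear(8, 8)
model = nn.Sequential(encoder, MeanPooling())

tokens = torch.randn(2, 5, 8)        # (batch, seq_len, hidden)
sentence_embeddings = model(tokens)  # (batch, hidden)
loss = sentence_embeddings.pow(2).sum()
loss.backward()

# Gradients flow through the pooling back into the encoder weights.
assert encoder.weight.grad is not None
print(sentence_embeddings.shape)  # torch.Size([2, 8])
```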
Would it make sense to fine-tune them? If you have enough training data, I think it would make sense.
By the way, most of our models are available in the Hugging Face model hub:
https://huggingface.co/sentence-transformers
I see that flair has a TransformerDocumentEmbeddings class, so you could try this:

```python
document_embeddings = TransformerDocumentEmbeddings('sentence-transformers/bert-base-nli-mean-tokens', fine_tune=True)
```

This loads our sentence-transformers bert-base-nli-mean-tokens model, without any pooling layer.
I am not sure what pooling strategy TransformerDocumentEmbeddings uses. Does it use mean pooling, or the CLS token as the embedding?
If it uses the CLS token, then this would be the right model:
https://huggingface.co/sentence-transformers/bert-base-nli-cls-token
Best
Nils Reimers
Ah interesting, thanks! Yes, the TransformerDocumentEmbeddings class uses the CLS token or equivalent.
So for the CLS sentence-transformers models, I guess we don't actually need a separate class. For the others, we would need to add a mean pooling layer; then we could have all transformers in one class, right?
Hi @alanakbik
Yes, adding mean pooling to the TransformerDocumentEmbeddings class would be quite nice.
See here for how to do this with minimal code using the HF AutoModel:
https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens
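For reference, the mean-pooling recipe described on that model card boils down to masking out padding tokens and averaging the rest. A sketch on dummy tensors (in practice, token_embeddings and attention_mask would come from an AutoModel forward pass and its tokenizer):

```python
import torch

def mean_pooling(token_embeddings, attention_mask):
    """Average token embeddings, counting only non-padding positions."""
    # Expand the mask to the hidden dimension: (batch, seq_len) -> (batch, seq_len, 1)
    mask = attention_mask.unsqueeze(-1).float()
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # avoid division by zero
    return summed / counts

# Dummy stand-ins for AutoModel outputs: batch of 2, seq_len 4, hidden size 6
token_embeddings = torch.randn(2, 4, 6)
attention_mask = torch.tensor([[1, 1, 1, 0],   # last token is padding
                               [1, 1, 0, 0]])  # last two tokens are padding

sentence_embeddings = mean_pooling(token_embeddings, attention_mask)
print(sentence_embeddings.shape)  # torch.Size([2, 6])
```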
Sometimes max pooling is also quite nice. Here you can find code showing how to do max pooling with the HF AutoModel:
https://huggingface.co/sentence-transformers/bert-base-nli-max-tokens
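The max-pooling variant from that model card works similarly: push padding positions to a very large negative value before taking the element-wise max, so padding never wins. Again sketched on dummy tensors rather than real AutoModel outputs:

```python
import torch

def max_pooling(token_embeddings, attention_mask):
    """Element-wise max over tokens, ignoring padding positions."""
    mask = attention_mask.unsqueeze(-1).bool()
    # Set padding positions to a very large negative value so they never win the max
    masked = token_embeddings.masked_fill(~mask, -1e9)
    return masked.max(dim=1).values

token_embeddings = torch.randn(2, 4, 6)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

sentence_embeddings = max_pooling(token_embeddings, attention_mask)
print(sentence_embeddings.shape)  # torch.Size([2, 6])
```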
The pooling mechanism could be added as a parameter to the TransformerDocumentEmbeddings class.
Best
Nils
@nreimers thanks for the info - we'll get right on it :)
Hi! Was mean pooling ever added in a recent release? I was curious to try it 😃