Transformers: multitask learning

Created on 17 Nov 2019  ·  4 comments  ·  Source: huggingface/transformers

โ“ Questions & Help


Hi, I would like to apply multitask learning using, for example, two tasks, where one (and maybe both) of them is sequence labeling such as NER. To my understanding, the way to apply the library to such a task is with BertForTokenClassification. Is that correct? My concern, though, is that it does not give me enough flexibility to adapt it into a multitask model.

Could you share your thoughts on that? Any help would be much appreciated.


All 4 comments

From the documentation of the class,
https://github.com/huggingface/transformers/blob/933841d903a032d93b5100220dc72db9d1283eca/pytorch_transformers/modeling_bert.py#L1100

I understand that I could use the scores as input to an additional module stacked on top of BertForTokenClassification, for example for a second task. Is that correct?

loss, scores = outputs[:2]  # scores: [batch_size, seq_len, num_labels]

But what I am thinking right now is that the scores have a small last dimension (just the number of labels), so I would probably need the representations from the last layer instead. How could I extract them?

As always, your thoughts on that would be much appreciated!

For anyone else struggling and needing an answer to that last question: I think the way to extract those representations is to open the black box of this class's implementation, where you will find what you want:
outputs = self.bert(..)
So I think a reimplementation/enhancement is needed to support my needs as described above.
Am I missing something if I reimplement this class, adding more functionality? I think not, and that it is safe.
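
For reference, here is a minimal sketch of such a subclass. It assumes the transformers 2.x API, where BertModel returns a tuple whose first element is the last layer's hidden states; the class name is hypothetical and the head mirrors the one in BertForTokenClassification:

from torch import nn
from transformers import BertModel, BertPreTrainedModel

class BertForTokenClassificationWithStates(BertPreTrainedModel):
    # Illustrative subclass: returns the token classification logits
    # together with the last layer hidden states, so that a second
    # task head can be stacked on top of them.
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]  # [batch_size, seq_len, hidden_size]
        logits = self.classifier(self.dropout(sequence_output))
        return logits, sequence_output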

Hi, there are several things you can do to obtain the last layer representation.

  • First of all, you can use a standard BertModel on which you add your own classifier for token classification (that's what's done within BertForTokenClassification). This will allow you to easily switch the heads for your multi-task setup; a sketch of this approach appears after the example below.

  • You could also use the BertForTokenClassification, as you have said, and use the inner model (model.bert) to obtain the last layer.

  • Finally, the cleanest way would be to output hidden states directly by specifying the option in the configuration:

from transformers import BertConfig, BertForTokenClassification
import torch

config = BertConfig.from_pretrained("bert-base-cased")
config.output_hidden_states = True  # return the hidden states of every layer

model = BertForTokenClassification.from_pretrained("bert-base-cased", config=config)

inputs = torch.tensor([[1, 2, 3]])  # dummy batch of token ids

outputs = model(inputs)
token_classification_outputs, hidden_states = outputs

last_layer_hidden_states = hidden_states[-1]  # hidden_states[0] is the embedding output

The variable last_layer_hidden_states is of shape [batch_size, seq_len, hidden_size] and is the output of the last transformer layer.
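
Building on the first option above, a minimal sketch of a shared encoder with two task heads might look like the following; the head names and label counts are illustrative assumptions, not part of the library:

import torch
from torch import nn
from transformers import BertModel

class MultiTaskBert(nn.Module):
    # Illustrative model: one shared BERT encoder, two token-level heads.
    def __init__(self, num_ner_labels=9, num_task2_labels=5):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-cased")
        hidden_size = self.bert.config.hidden_size
        self.ner_head = nn.Linear(hidden_size, num_ner_labels)
        self.task2_head = nn.Linear(hidden_size, num_task2_labels)

    def forward(self, input_ids, attention_mask=None):
        # The first element of the output tuple is the last layer's hidden states.
        sequence_output = self.bert(input_ids, attention_mask=attention_mask)[0]
        return self.ner_head(sequence_output), self.task2_head(sequence_output)

model = MultiTaskBert()
ner_logits, task2_logits = model(torch.tensor([[1, 2, 3]]))

Each head then produces scores of shape [batch_size, seq_len, num_labels] for its task, and you can compute one loss per task and sum them for a joint update.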

I hope this clears things up.

