Hello COVID-19 survivors,
I have been trying to use GPT for token classification, but Hugging Face does not currently provide a token classification head for it, so I copied the code from `BertForTokenClassification` and stitched together the class below. However, it reports that none of the weights were initialized from the pretrained model. Did I make a mistake? Please help!
```python
import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers import OpenAIGPTModel, OpenAIGPTPreTrainedModel
from transformers.file_utils import add_start_docstrings


class GPTClassifier(OpenAIGPTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = 2
        self.gpt = OpenAIGPTModel(config)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(config.n_embd, self.num_labels)
        self.init_weights()

    @add_start_docstrings(
        """GPT model with a token classification head on top (a linear layer on top of
        the hidden-states output), e.g. for Named-Entity-Recognition (NER) tasks."""
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        outputs = self.gpt(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        # GPT has no pooled output, so extras start at index 1 (not 2 as in BERT)
        outputs = (logits,) + outputs[1:]  # add hidden states and attentions if present
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep the active (non-padding) positions in the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)
```
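The warning appears at load time; for reference, a minimal repro (assuming the stock `openai-gpt` checkpoint):

```python
# Loading pretrained weights triggers the warning shown below
model = GPTClassifier.from_pretrained("openai-gpt")
```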
The weights-not-initialized warning:
```
Weights of GPTClassifier not initialized from pretrained model: ['gpt.tokens_embed.weight', 'gpt.positions_embed.weight', 'gpt.h.0.attn.bias', 'gpt.h.0.attn.c_attn.weight', 'gpt.h.0.attn.c_attn.bias', 'gpt.h.0.attn.c_proj.weight', 'gpt.h.0.attn.c_proj.bias', 'gpt.h.0.ln_1.weight', 'gpt.h.0.ln_1.bias', 'gpt.h.0.mlp.c_fc.weight', 'gpt.h.0.mlp.c_fc.bias', 'gpt.h.0.mlp.c_proj.weight', 'gpt.h.0.mlp.c_proj.bias', 'gpt.h.0.ln_2.weight', 'gpt.h.0.ln_2.bias', 'gpt.h.1.attn.bias', 'gpt.h.1.attn.c_attn.weight', 'gpt.h.1.attn.c_attn.bias', 'gpt.h.1.attn.c_proj.weight', 'gpt.h.1.attn.c_proj.bias', 'gpt.h.1.ln_1.weight', 'gpt.h.1.ln_1.bias', 'gpt.h.1.mlp.c_fc.weight', 'gpt.h.1.mlp.c_fc.bias', 'gpt.h.1.mlp.c_proj.weight', 'gpt.h.1.mlp.c_proj.bias', 'gpt.h.1.ln_2.weight', 'gpt.h.1.ln_2.bias', 'gpt.h.2.attn.bias', 'gpt.h.2.attn.c_attn.weight', 'gpt.h.2.attn.c_attn.bias', 'gpt.h.2.attn.c_proj.weight', 'gpt.h.2.attn.c_proj.bias', 'gpt.h.2.ln_1.weight', 'gpt.h.2.ln_1.bias', 'gpt.h.2.mlp.c_fc.weight', 'gpt.h.2.mlp.c_fc.bias', 'gpt.h.2.mlp.c_proj.weight', 'gpt.h.2.mlp.c_proj.bias', 'gpt.h.2.ln_2.weight', 'gpt.h.2.ln_2.bias', 'gpt.h.3.attn.bias', 'gpt.h.3.attn.c_attn.weight', 'gpt.h.3.attn.c_attn.bias', 'gpt.h.3.attn.c_proj.weight', 'gpt.h.3.attn.c_proj.bias', 'gpt.h.3.ln_1.weight', 'gpt.h.3.ln_1.bias', 'gpt.h.3.mlp.c_fc.weight', 'gpt.h.3.mlp.c_fc.bias', 'gpt.h.3.mlp.c_proj.weight', 'gpt.h.3.mlp.c_proj.bias', 'gpt.h.3.ln_2.weight', 'gpt.h.3.ln_2.bias', 'gpt.h.4.attn.bias', 'gpt.h.4.attn.c_attn.weight', 'gpt.h.4.attn.c_attn.bias', 'gpt.h.4.attn.c_proj.weight', 'gpt.h.4.attn.c_proj.bias', 'gpt.h.4.ln_1.weight', 'gpt.h.4.ln_1.bias', 'gpt.h.4.mlp.c_fc.weight', 'gpt.h.4.mlp.c_fc.bias', 'gpt.h.4.mlp.c_proj.weight', 'gpt.h.4.mlp.c_proj.bias', 'gpt.h.4.ln_2.weight', 'gpt.h.4.ln_2.bias', 'gpt.h.5.attn.bias', 'gpt.h.5.attn.c_attn.weight', 'gpt.h.5.attn.c_attn.bias', 'gpt.h.5.attn.c_proj.weight', 'gpt.h.5.attn.c_proj.bias', 'gpt.h.5.ln_1.weight', 'gpt.h.5.ln_1.bias', 'gpt.h.5.mlp.c_fc.weight', 'gpt.h.5.mlp.c_fc.bias', 'gpt.h.5.mlp.c_proj.weight', 'gpt.h.5.mlp.c_proj.bias', 'gpt.h.5.ln_2.weight', 'gpt.h.5.ln_2.bias', 'gpt.h.6.attn.bias', 'gpt.h.6.attn.c_attn.weight', 'gpt.h.6.attn.c_attn.bias', 'gpt.h.6.attn.c_proj.weight', 'gpt.h.6.attn.c_proj.bias', 'gpt.h.6.ln_1.weight', 'gpt.h.6.ln_1.bias', 'gpt.h.6.mlp.c_fc.weight', 'gpt.h.6.mlp.c_fc.bias', 'gpt.h.6.mlp.c_proj.weight', 'gpt.h.6.mlp.c_proj.bias', 'gpt.h.6.ln_2.weight', 'gpt.h.6.ln_2.bias', 'gpt.h.7.attn.bias', 'gpt.h.7.attn.c_attn.weight', 'gpt.h.7.attn.c_attn.bias', 'gpt.h.7.attn.c_proj.weight', 'gpt.h.7.attn.c_proj.bias', 'gpt.h.7.ln_1.weight', 'gpt.h.7.ln_1.bias', 'gpt.h.7.mlp.c_fc.weight', 'gpt.h.7.mlp.c_fc.bias', 'gpt.h.7.mlp.c_proj.weight', 'gpt.h.7.mlp.c_proj.bias', 'gpt.h.7.ln_2.weight', 'gpt.h.7.ln_2.bias', 'gpt.h.8.attn.bias', 'gpt.h.8.attn.c_attn.weight', 'gpt.h.8.attn.c_attn.bias', 'gpt.h.8.attn.c_proj.weight', 'gpt.h.8.attn.c_proj.bias', 'gpt.h.8.ln_1.weight', 'gpt.h.8.ln_1.bias', 'gpt.h.8.mlp.c_fc.weight', 'gpt.h.8.mlp.c_fc.bias', 'gpt.h.8.mlp.c_proj.weight', 'gpt.h.8.mlp.c_proj.bias', 'gpt.h.8.ln_2.weight', 'gpt.h.8.ln_2.bias', 'gpt.h.9.attn.bias', 'gpt.h.9.attn.c_attn.weight', 'gpt.h.9.attn.c_attn.bias', 'gpt.h.9.attn.c_proj.weight', 'gpt.h.9.attn.c_proj.bias', 'gpt.h.9.ln_1.weight', 'gpt.h.9.ln_1.bias', 'gpt.h.9.mlp.c_fc.weight', 'gpt.h.9.mlp.c_fc.bias', 'gpt.h.9.mlp.c_proj.weight', 'gpt.h.9.mlp.c_proj.bias', 'gpt.h.9.ln_2.weight', 'gpt.h.9.ln_2.bias', 'gpt.h.10.attn.bias', 'gpt.h.10.attn.c_attn.weight', 'gpt.h.10.attn.c_attn.bias', 
'gpt.h.10.attn.c_proj.weight', 'gpt.h.10.attn.c_proj.bias', 'gpt.h.10.ln_1.weight', 'gpt.h.10.ln_1.bias', 'gpt.h.10.mlp.c_fc.weight', 'gpt.h.10.mlp.c_fc.bias', 'gpt.h.10.mlp.c_proj.weight', 'gpt.h.10.mlp.c_proj.bias', 'gpt.h.10.ln_2.weight', 'gpt.h.10.ln_2.bias', 'gpt.h.11.attn.bias', 'gpt.h.11.attn.c_attn.weight', 'gpt.h.11.attn.c_attn.bias', 'gpt.h.11.attn.c_proj.weight', 'gpt.h.11.attn.c_proj.bias', 'gpt.h.11.ln_1.weight', 'gpt.h.11.ln_1.bias', 'gpt.h.11.mlp.c_fc.weight', 'gpt.h.11.mlp.c_fc.bias', 'gpt.h.11.mlp.c_proj.weight', 'gpt.h.11.mlp.c_proj.bias', 'gpt.h.11.ln_2.weight', 'gpt.h.11.ln_2.bias', 'classifier.weight', 'classifier.bias']
I0612 13:58:30.055731 4472821184 modeling_utils.py:460] Weights from pretrained model not used in GPTClassifier: ['tokens_embed.weight', 'positions_embed.weight', 'h.0.attn.bias', 'h.0.attn.c_attn.weight', 'h.0.attn.c_attn.bias', 'h.0.attn.c_proj.weight', 'h.0.attn.c_proj.bias', 'h.0.ln_1.weight', 'h.0.ln_1.bias', 'h.0.mlp.c_fc.weight', 'h.0.mlp.c_fc.bias', 'h.0.mlp.c_proj.weight', 'h.0.mlp.c_proj.bias', 'h.0.ln_2.weight', 'h.0.ln_2.bias', 'h.1.attn.bias', 'h.1.attn.c_attn.weight', 'h.1.attn.c_attn.bias', 'h.1.attn.c_proj.weight', 'h.1.attn.c_proj.bias', 'h.1.ln_1.weight', 'h.1.ln_1.bias', 'h.1.mlp.c_fc.weight', 'h.1.mlp.c_fc.bias', 'h.1.mlp.c_proj.weight', 'h.1.mlp.c_proj.bias', 'h.1.ln_2.weight', 'h.1.ln_2.bias', 'h.2.attn.bias', 'h.2.attn.c_attn.weight', 'h.2.attn.c_attn.bias', 'h.2.attn.c_proj.weight', 'h.2.attn.c_proj.bias', 'h.2.ln_1.weight', 'h.2.ln_1.bias', 'h.2.mlp.c_fc.weight', 'h.2.mlp.c_fc.bias', 'h.2.mlp.c_proj.weight', 'h.2.mlp.c_proj.bias', 'h.2.ln_2.weight', 'h.2.ln_2.bias', 'h.3.attn.bias', 'h.3.attn.c_attn.weight', 'h.3.attn.c_attn.bias', 'h.3.attn.c_proj.weight', 'h.3.attn.c_proj.bias', 'h.3.ln_1.weight', 'h.3.ln_1.bias', 'h.3.mlp.c_fc.weight', 'h.3.mlp.c_fc.bias', 'h.3.mlp.c_proj.weight', 'h.3.mlp.c_proj.bias', 'h.3.ln_2.weight', 'h.3.ln_2.bias', 'h.4.attn.bias', 'h.4.attn.c_attn.weight', 'h.4.attn.c_attn.bias', 'h.4.attn.c_proj.weight', 'h.4.attn.c_proj.bias', 'h.4.ln_1.weight', 'h.4.ln_1.bias', 'h.4.mlp.c_fc.weight', 'h.4.mlp.c_fc.bias', 'h.4.mlp.c_proj.weight', 'h.4.mlp.c_proj.bias', 'h.4.ln_2.weight', 'h.4.ln_2.bias', 'h.5.attn.bias', 'h.5.attn.c_attn.weight', 'h.5.attn.c_attn.bias', 'h.5.attn.c_proj.weight', 'h.5.attn.c_proj.bias', 'h.5.ln_1.weight', 'h.5.ln_1.bias', 'h.5.mlp.c_fc.weight', 'h.5.mlp.c_fc.bias', 'h.5.mlp.c_proj.weight', 'h.5.mlp.c_proj.bias', 'h.5.ln_2.weight', 'h.5.ln_2.bias', 'h.6.attn.bias', 'h.6.attn.c_attn.weight', 'h.6.attn.c_attn.bias', 'h.6.attn.c_proj.weight', 'h.6.attn.c_proj.bias', 'h.6.ln_1.weight', 'h.6.ln_1.bias', 'h.6.mlp.c_fc.weight', 'h.6.mlp.c_fc.bias', 'h.6.mlp.c_proj.weight', 'h.6.mlp.c_proj.bias', 'h.6.ln_2.weight', 'h.6.ln_2.bias', 'h.7.attn.bias', 'h.7.attn.c_attn.weight', 'h.7.attn.c_attn.bias', 'h.7.attn.c_proj.weight', 'h.7.attn.c_proj.bias', 'h.7.ln_1.weight', 'h.7.ln_1.bias', 'h.7.mlp.c_fc.weight', 'h.7.mlp.c_fc.bias', 'h.7.mlp.c_proj.weight', 'h.7.mlp.c_proj.bias', 'h.7.ln_2.weight', 'h.7.ln_2.bias', 'h.8.attn.bias', 'h.8.attn.c_attn.weight', 'h.8.attn.c_attn.bias', 'h.8.attn.c_proj.weight', 'h.8.attn.c_proj.bias', 'h.8.ln_1.weight', 'h.8.ln_1.bias', 'h.8.mlp.c_fc.weight', 'h.8.mlp.c_fc.bias', 'h.8.mlp.c_proj.weight', 'h.8.mlp.c_proj.bias', 'h.8.ln_2.weight', 'h.8.ln_2.bias', 'h.9.attn.bias', 'h.9.attn.c_attn.weight', 'h.9.attn.c_attn.bias', 'h.9.attn.c_proj.weight', 'h.9.attn.c_proj.bias', 'h.9.ln_1.weight', 'h.9.ln_1.bias', 'h.9.mlp.c_fc.weight', 'h.9.mlp.c_fc.bias', 'h.9.mlp.c_proj.weight', 'h.9.mlp.c_proj.bias', 'h.9.ln_2.weight', 'h.9.ln_2.bias', 'h.10.attn.bias', 'h.10.attn.c_attn.weight', 'h.10.attn.c_attn.bias', 'h.10.attn.c_proj.weight', 'h.10.attn.c_proj.bias', 'h.10.ln_1.weight', 'h.10.ln_1.bias', 'h.10.mlp.c_fc.weight', 'h.10.mlp.c_fc.bias', 'h.10.mlp.c_proj.weight', 'h.10.mlp.c_proj.bias', 'h.10.ln_2.weight', 'h.10.ln_2.bias', 'h.11.attn.bias', 'h.11.attn.c_attn.weight', 'h.11.attn.c_attn.bias', 'h.11.attn.c_proj.weight', 'h.11.attn.c_proj.bias', 'h.11.ln_1.weight', 'h.11.ln_1.bias', 'h.11.mlp.c_fc.weight', 'h.11.mlp.c_fc.bias', 'h.11.mlp.c_proj.weight', 'h.11.mlp.c_proj.bias', 'h.11.ln_2.weight', 
'h.11.ln_2.bias']
```
Never mind, resolved now!
Thanks!
How did you resolve it?
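For anyone hitting the same warning, the likely cause (not confirmed as the OP's actual fix) is the name of the backbone attribute. `from_pretrained` maps checkpoint keys onto the new model only through the class's `base_model_prefix`, which is `"transformer"` for `OpenAIGPTPreTrainedModel`. The checkpoint keys in the log carry no prefix, while the model registers everything under `gpt.`, so nothing matches and every weight stays randomly initialized. A minimal sketch of the rename:

```python
class GPTClassifier(OpenAIGPTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = 2
        # Must be named `transformer` to match base_model_prefix,
        # so from_pretrained can line up the checkpoint keys
        self.transformer = OpenAIGPTModel(config)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(config.n_embd, self.num_labels)
        self.init_weights()
```

`forward` then calls `self.transformer(...)` instead of `self.gpt(...)`.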