Transformers: Use fine-tuned BART-large to do conditional generation

Created on 4 May 2020 · 9 comments · Source: huggingface/transformers

Hi

I am using a slightly old tag of your repo, where BART had run_bart_sum.py. I fine-tuned bart-large on a custom dataset and want to do conditional generation:

from transformers import BartTokenizer, BartForConditionalGeneration
import torch

model = BartForConditionalGeneration.from_pretrained('bart-large')
tokenizer = BartTokenizer.from_pretrained('bart-large')

ARTICLE_TO_SUMMARIZE = "President Donald Trump's senior adviser and son-in-law, Jared Kushner, praised the administration's response to the coronavirus pandemic as a \"great success story\" on Wednesday -- less than a day after the number of confirmed coronavirus cases in the United States topped 1 million. Kushner painted a rosy picture for \"Fox and Friends\" Wednesday morning, saying that \"the federal government rose to the challenge and this is a great success story and I think that that's really what needs to be told.\""


# Attempt 1: load the fine-tuned checkpoint directly (fails, see traceback below)
# model = BartForConditionalGeneration.from_pretrained('./bart_sum/checkpointepoch=2.ckpt')
# tokenizer = BartTokenizer.from_pretrained('./bart_sum/checkpointepoch=2.ckpt')

# Attempt 2: load bart-large, then apply the checkpoint as a state dict (also fails, see error below)
model = BartForConditionalGeneration.from_pretrained('bart-large')
tokenizer = BartTokenizer.from_pretrained('bart-large')
state = torch.load('./bart_sum/checkpointepoch=2.ckpt', map_location='cpu')
model.load_state_dict(state)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()


inputs = tokenizer.batch_encode_plus(
    [ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
summary_ids = model.generate(
    inputs['input_ids'].to(device), num_beams=1, max_length=512, early_stopping=True)


print([tokenizer.decode(g, skip_special_tokens=True,
                        clean_up_tokenization_spaces=False)
       for g in summary_ids])


I tried both loading the fine-tuned checkpoint directly and loading bart-large and then setting the state dict.

For the former it gives me:

Traceback (most recent call last):
  File "generate.py", line 10, in <module>
    model = BartForConditionalGeneration.from_pretrained('./bart_sum/checkpointepoch=2.ckpt')
  File "/datastor/Softwarez/miniconda3/lib/python3.7/site-packages/transformers/modeling_utils.py", line 438, in from_pretrained
    **kwargs,
  File "/datastor/Softwarez/miniconda3/lib/python3.7/site-packages/transformers/configuration_utils.py", line 200, in from_pretrained
    config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
  File "/datastor/Softwarez/miniconda3/lib/python3.7/site-packages/transformers/configuration_utils.py", line 252, in get_config_dict
    config_dict = cls._dict_from_json_file(resolved_config_file)
  File "/datastor/Softwarez/miniconda3/lib/python3.7/site-packages/transformers/configuration_utils.py", line 344, in _dict_from_json_file
    text = reader.read()
  File "/datastor/Softwarez/miniconda3/lib/python3.7/codecs.py", line 322, in decode
    (result, consumed) = self._buffer_decode(data, self.errors, final)
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 0: invalid start byte

For the latter:
Unexpected key(s) in state_dict: "epoch", "global_step", "checkpoint_callback_best", "optimizer_states", "lr_schedulers", "state_dict", "hparams", "hparams_type".
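
The unexpected keys suggest the .ckpt file is a full pytorch-lightning checkpoint (training metadata plus a nested state dict) rather than a bare model state dict. A quick way to confirm this (a minimal check, assuming a standard Lightning checkpoint):

import torch

ckpt = torch.load('./bart_sum/checkpointepoch=2.ckpt', map_location='cpu')
print(ckpt.keys())               # training metadata plus 'state_dict'
print(len(ckpt['state_dict']))   # the actual model weights are nested here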


All 9 comments

If you used pytorch-lightning for training, then you can load the weights from the checkpoint as follows:

ckpt = torch.load('./bart_sum/checkpointepoch=2.ckpt')
model.load_state_dict(ckpt['state_dict'])

Once you have loaded the weights this way, save the model using the .save_pretrained method so that next time you can load it using .from_pretrained.

@patil-suraj, I tried your suggestion on a fine-tuned BART checkpoint, but it gives me the following error. P.S. The model and tokenizer used are "bart-large".
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input> in <module>()
      1 ckpt = torch.load('./OUTPUT_DIR/checkpointcheckpoint_ckpt_epoch_2.ckpt')
----> 2 model.load_state_dict(ckpt['state_dict'])

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    845         if len(error_msgs) > 0:
    846             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 847                 self.__class__.__name__, "\n\t".join(error_msgs)))
    848         return _IncompatibleKeys(missing_keys, unexpected_keys)
    849

RuntimeError: Error(s) in loading state_dict for BartForConditionalGeneration:
    Missing key(s) in state_dict: "final_logits_bias", "model.shared.weight", "model.encoder.embed_tokens.weight", "model.encoder.embed_positions.weight", "model.encoder.layers.0.self_attn.k_proj.weight", ... (the expected BART weight names, with no leading "model." prefix)
    Unexpected key(s) in state_dict: "model.final_logits_bias", "model.model.shared.weight", "model.model.encoder.embed_tokens.weight", "model.model.encoder.embed_positions.weight", "model.model.encoder.layers.0.self_attn.k_proj.weight", ... (the same weights, each with an extra "model." prefix from the LightningModule wrapper)

Please let me know how to tackle this.

@pranavpawar3
Here, model should be an instance of the LightningModule. Initialize the LightningModule, then you'll be able to do it this way:

ckpt = torch.load('./OUTPUT_DIR/checkpointcheckpoint_ckpt_epoch_2.ckpt')
model.load_state_dict(ckpt['state_dict'])

# save the inner pretrained model
model.model.save_pretrained('model_dir')

# then you can load it using BartForConditionalGeneration
BartForConditionalGeneration.from_pretrained('model_dir')

@patil-suraj Initializing the model as a LightningModule instance worked, thanks!

@pranavpawar3 Can I ask you to share how you initialized the LightningModule instance to make it compatible with the model you fine-tuned based on the pretrained bart-large model? I'm having the same issue. Thanks!

@patil-suraj Could you please show how to initialize the model as a LightningModule instance? I have the same problem loading a fine-tuned BART ckpt. Thanks in advance!

@sshleifer Thanks for the link; in the meantime I managed to do what I wanted.
Anyway, I will be glad to see further improvements for summarization tasks.

For those who fine-tuned a BART model with finetune_bart.sh and want to load it in PyTorch, the following worked for me.

import torch
import pytorch_lightning as pl
from transformers import BartForConditionalGeneration


class BartModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # the attribute must be named 'model' so the checkpoint keys ("model. ...") line up
        self.model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')

    def forward(self):
        pass

ckpt = torch.load('./bart_sum/checkpointepoch=1.ckpt')

bart_model = BartModel()
bart_model.load_state_dict(ckpt['state_dict'])
bart_model.model.save_pretrained("working_dir")
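
Once the weights are saved to working_dir this way, generation is the usual from_pretrained/generate flow (a minimal sketch; the directory name, tokenizer choice, and generation parameters are just examples):

from transformers import BartForConditionalGeneration, BartTokenizer

# reload the fine-tuned weights saved above in huggingface format
model = BartForConditionalGeneration.from_pretrained('working_dir')
# the tokenizer was not changed by fine-tuning, so the pretrained one is reused
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')

text = "Some article text to summarize."
inputs = tokenizer.batch_encode_plus([text], max_length=1024, return_tensors='pt')
summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=142, early_stopping=True)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))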

Just merged a bunch of changes to the summarization finetuning code. Long description here: https://github.com/huggingface/transformers/pull/4951.
Would love it if somebody could take the new README/code for a spin!

Some improvements (sorry to repeat myself):

  • You can finetune bart a lot faster with --freeze_encoder and --freeze_embeds.
  • You can collaborate with the community on hyperparams/modifications for the XSUM task using --logger wandb_shared.
  • Upgrade to pytorch_lightning==0.7.6.
  • You get a huggingface-style checkpoint associated with the .ckpt checkpoint, via the new rouge2-based model checkpointing.
  • Rouge (the canonical summarization metric) is calculated at every val step, which is slow, so you can use --val_check_interval 0.1 --n_val 500 to compute rouge more frequently on a subset of the validation set.

It's probably not perfect at the moment, so I'd love to know if anything is confusing or broken, either here or in a new issue :)
Thanks!

@sshleifer Hi, I checked the changes.

It works well; thanks for the automatic model saving to PyTorch format.

It's also good to see the tip in the README to use bart-large-xsum for short summaries; I tried it with my dataset instead of bart-large and it improved my score!

Are there any tips for T5? What type of T5 model is better for short summaries?

