Transformers: how to continue training from a checkpoint with Trainer?

Created on 17 Sep 2020 · 2 Comments · Source: huggingface/transformers

โ“ Questions & Help

Details

I am trying to continue training my model (GPT-2) from a checkpoint using Trainer. However, when I do so, training starts from step 0 instead of resuming from the checkpoint. I'm sharing my code because I can't tell where I'm making the mistake.

import torch
from transformers import TextDataset, DataCollatorForLanguageModeling, AutoTokenizer, GPT2LMHeadModel, Trainer, TrainingArguments

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = AutoTokenizer.from_pretrained("gpt2-large")

train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path='textfile (1).txt',
    block_size=128,
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

model = GPT2LMHeadModel.from_pretrained("checkpoint-9500").to(device)  # Here I load from the checkpoint

training_args = TrainingArguments(
    output_dir='./results',         # output directory
    num_train_epochs=4,              # total # of training epochs
    per_device_train_batch_size=1,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    #eval_dataset=validation_dataset,
    prediction_loss_only=True,
)

trainer.train()
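Loading `checkpoint-9500` with `GPT2LMHeadModel.from_pretrained` restores only the model weights. When `trainer.train()` is called with no checkpoint path, the Trainer still creates a fresh optimizer and learning-rate scheduler and counts steps from 0, so with `warmup_steps=500` the learning rate ramps up from zero again. The pure-Python sketch below mirrors the linear warmup-then-decay multiplier that, as far as I know, is the Trainer's default schedule (`get_linear_schedule_with_warmup`). The function name `linear_warmup_factor` and the total step count are hypothetical, chosen only for illustration.

```python
# Sketch of a linear warmup-then-decay LR multiplier, assumed to match the
# shape of transformers' get_linear_schedule_with_warmup (the Trainer default).
def linear_warmup_factor(step: int, warmup_steps: int, total_steps: int) -> float:
    """Return the multiplier applied to the base learning rate at `step`."""
    if step < warmup_steps:
        # Warmup phase: ramp linearly from 0 up to 1.
        return step / max(1, warmup_steps)
    # Decay phase: ramp linearly from 1 down to 0 at total_steps.
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


WARMUP = 500    # warmup_steps from the question's TrainingArguments
TOTAL = 20_000  # hypothetical total number of optimizer steps

# Weights-only reload: the new scheduler starts at step 0, so the LR is 0 again.
print(linear_warmup_factor(0, WARMUP, TOTAL))  # 0.0
# Proper resume: the restored scheduler continues from step 9500 (mid-decay).
print(linear_warmup_factor(9500, WARMUP, TOTAL))
```

The point of the comparison: a weights-only reload silently replays warmup and the full decay, whereas a real resume continues the schedule where it stopped.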



Thanks a lot for the help.

All 2 comments

Hi there, you have to pass the checkpoint path to the Trainer.train method to resume training:

trainer.train("checkpoint-9500")
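The path passed to `trainer.train` is one of the `checkpoint-<step>` folders the Trainer writes into `output_dir` (here `./results`). To avoid hard-coding `"checkpoint-9500"`, a small helper like the hypothetical one below can pick the most recent folder by its step number. It is plain Python, not part of the transformers API; newer transformers releases also provide `transformers.trainer_utils.get_last_checkpoint(output_dir)` for the same purpose, if your version includes it.

```python
import os
import re
from typing import List, Optional

# Hypothetical helpers (not part of transformers). They assume the Trainer's
# folder naming convention: "checkpoint-<global_step>".
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def step_from_checkpoint_dir(path: str) -> Optional[int]:
    """Return the global step encoded in a checkpoint folder name, or None."""
    name = os.path.basename(os.path.normpath(path))
    match = _CHECKPOINT_RE.match(name)
    return int(match.group(1)) if match else None


def latest_checkpoint(names: List[str]) -> Optional[str]:
    """Return the checkpoint name with the highest step, comparing steps as integers."""
    candidates = [(step_from_checkpoint_dir(n), n) for n in names]
    candidates = [(step, n) for step, n in candidates if step is not None]
    return max(candidates)[1] if candidates else None


print(step_from_checkpoint_dir("./results/checkpoint-9500"))  # 9500
print(latest_checkpoint(["checkpoint-500", "checkpoint-9500", "checkpoint-10000", "runs"]))  # checkpoint-10000
```

Steps are compared as integers on purpose: sorting the names as strings would rank "checkpoint-9500" above "checkpoint-10000". In practice you would pass os.listdir("./results") to latest_checkpoint and hand os.path.join("./results", <result>) to trainer.train.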

If you set your logging verbosity to the INFO level (transformers.logging.set_verbosity_info()), you should then see information about training being resumed and the number of steps skipped.
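When a checkpoint path is given, the Trainer reloads the saved global step and then skips the batches it has already processed in the current epoch; that is the "number of steps skipped" the INFO logs report. To my understanding, it splits the global step into completed epochs and steps into the current epoch, with updates per epoch computed as batches per epoch divided by `gradient_accumulation_steps`. The sketch below reproduces that arithmetic; `resume_position` and the 3,000 batches per epoch are hypothetical.

```python
from typing import Tuple


def resume_position(global_step: int, batches_per_epoch: int,
                    gradient_accumulation_steps: int = 1) -> Tuple[int, int]:
    """Split a saved global step into (epochs completed, updates into current epoch).

    One "step" is one optimizer update, i.e. `gradient_accumulation_steps` batches.
    """
    updates_per_epoch = max(1, batches_per_epoch // gradient_accumulation_steps)
    epochs_trained = global_step // updates_per_epoch
    steps_in_current_epoch = global_step % updates_per_epoch
    return epochs_trained, steps_in_current_epoch


# Hypothetical: 3,000 batches per epoch (batch size 1), resuming at step 9500.
epochs_done, steps_to_skip = resume_position(9500, 3000)
print(epochs_done, steps_to_skip)  # 3 500
remaining = 4 * 3000 - 9500  # num_train_epochs=4 in the question's TrainingArguments
print(remaining)  # 2500
```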

Great, thanks a lot for your help, Sylvain.

Works perfectly.
