I am trying to continue training my model (GPT-2) from a checkpoint, using Trainer. However, when I run it, the model starts training from step 0 instead of resuming from the checkpoint. I'm sharing my code because I can't see where the mistake is.
import torch
from transformers import (
    TextDataset,
    DataCollatorForLanguageModeling,
    AutoTokenizer,
    GPT2LMHeadModel,
    Trainer,
    TrainingArguments,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = AutoTokenizer.from_pretrained("gpt2-large")

train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path='textfile (1).txt',
    block_size=128,
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

model = GPT2LMHeadModel.from_pretrained("checkpoint-9500").to(device)  # HERE I LOAD FROM CHECKPOINT

training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=4,              # total number of training epochs
    per_device_train_batch_size=1,   # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    #eval_dataset=validation_dataset,
    prediction_loss_only=True,
)

trainer.train()
Thanks a lot for the help.
Hi there, you have to pass the checkpoint path to the Trainer.train method to resume training:
trainer.train("checkpoint-9500")
If you set your logging verbosity to the INFO level (transformers.logging.set_verbosity_info()) you should then see information about the training resuming and the number of steps skipped.
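For reference, here is a minimal sketch of the resumed run, assuming the same trainer object as above. Recent transformers versions also accept the keyword form resume_from_checkpoint in addition to the positional path, so treat the exact spelling as version-dependent:
import transformers

# Show INFO-level logs so the resume message and the number of
# skipped steps are printed when training restarts.
transformers.logging.set_verbosity_info()

# Resume from the saved checkpoint; on newer versions the keyword
# form works, on older versions pass the path positionally.
trainer.train(resume_from_checkpoint="checkpoint-9500")
# trainer.train("checkpoint-9500")  # equivalent positional form
Note that passing the path here restores the optimizer and scheduler state and the global step counter, which loading the model weights alone (as in the original code) does not.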
Great, thanks a lot for your help Sylvain.
Works perfectly.