The only way I've found to resume training a model from the best checkpoint is to explicitly instantiate the Trainer with the following structure:
best_epoch = 250  # ?? (actually I don't know)
self.trainer = ptl.Trainer(
    logger=TestTubeLogger(
        save_dir=self.save_dir,
        version=0  # fixed
    ),
    resume_from_checkpoint=str(self.save_dir / 'default' / 'version_0' / 'checkpoints' / 'epoch={}.ckpt'.format(best_epoch)),
    default_root_dir=self.save_dir,
)
The problem is that it is not natural to know the exact epoch that holds the best model.
I think loading the best model is a pretty natural operation in most cases:
Training: You want to continue training from the best model.
Test: You want to test the best model.
I think having to write a method to find the "best checkpoint" is exactly the kind of boilerplate code this library tries to avoid.
Have you tried using trainer.checkpoint_callback.best_model_path?
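For reference, this is roughly how best_model_path gets populated within a single session (a minimal sketch; MyModel is an assumed LightningModule that logs val_loss):

import pytorch_lightning as ptl
from pytorch_lightning.callbacks import ModelCheckpoint

model = MyModel()  # assumed LightningModule that logs 'val_loss'
checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1)
trainer = ptl.Trainer(checkpoint_callback=checkpoint_callback, max_epochs=300)
trainer.fit(model)

# best_model_path is only filled in after fit() has run in this session
print(trainer.checkpoint_callback.best_model_path)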
But how would the trainer know the best_model_path before loading?
I mean, loading the Trainer in a fresh runtime is exactly what I want to do.
Oh, I'm not sure if there's anything to help with that yet. I guess right now we have to explicitly specify a path to the checkpoint.
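In a fresh runtime, explicitly specifying the path would look roughly like this (a sketch; the path and MyModel are placeholders, and load_from_checkpoint assumes the module's hyperparameters were saved along with the checkpoint):

import pytorch_lightning as ptl

ckpt_path = '/path/to/save_dir/default/version_0/checkpoints/epoch=250.ckpt'  # placeholder

# restore the saved weights into the (assumed) LightningModule
model = MyModel.load_from_checkpoint(ckpt_path)

# test that exact checkpoint ...
trainer = ptl.Trainer()
trainer.test(model)

# ... or keep training from it (optimizer/scheduler state is restored too)
trainer = ptl.Trainer(resume_from_checkpoint=ckpt_path)
trainer.fit(model)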
Training: You want to continue training from the best model.
How can a new Trainer instance know the best model checkpoint you saved using another Trainer instance you used to train the model before?
Test: You want to test the best model.
You can just use trainer.test(ckpt_path='best').
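A sketch of the intended usage within the same session (MyModel is an assumed LightningModule):

model = MyModel()
trainer = ptl.Trainer(max_epochs=300)
trainer.fit(model)

# 'best' resolves to checkpoint_callback.best_model_path from this fit() run,
# so it only works within the same Trainer instance / session
trainer.test(ckpt_path='best')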
Thank you for your comments.
About the training scheme: I'm pretty sure the TestTubeLogger did the best-model loading before I updated. The feature stopped working after updating PyTorch Lightning from 0.3 to 0.9.
About loading the best model into a Trainer instance: I thought about picking the checkpoint path with the highest epoch from the checkpoint folder and using the resume_from_checkpoint Trainer param to load it. I thought there'd be an easier way, but I guess not. Anyway, I'll keep this issue updated if I come up with any solution to this case.
I would recommend using the TensorBoard logger instead of the TestTube logger. Anyway, as mentioned above, the "best" is part of checkpointing, not the logger :]
@Borda Can you explain why you would prefer the TensorBoard logger? I thought using the Test Tube logger was encouraged, since it was used so much in the tutorials.
@williamFalcon ^^
This is the simple Trainer-and-Logger workaround I made for the purpose of this issue, based on what I had.
version = 1
logger = TestTubeLogger(
    save_dir=save_dir,
    version=version  # fixed so the checkpoint folder can be found on reload
)
ckpt_folder = save_dir / 'default' / 'version_{}'.format(version) / 'checkpoints'
best_epoch = find_best_epoch(ckpt_folder)
self.trainer = ptl.Trainer(
    logger=logger,
    resume_from_checkpoint=str(ckpt_folder / 'epoch={}.ckpt'.format(best_epoch)),
)
And the find_best_epoch function I defined is:
from os import listdir

def find_best_epoch(ckpt_folder):
    """
    Find the highest epoch in the Test Tube file structure.
    :param ckpt_folder: dir where the checkpoints are being saved.
    :return: Integer of the highest epoch reached by the checkpoints.
    """
    ckpt_files = listdir(ckpt_folder)  # list of strings
    epochs = [int(filename[6:-5]) for filename in ckpt_files]  # 'epoch={int}.ckpt' filename format
    return max(epochs)
I hope this may help someone.
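If the checkpoint filenames ever include more than the epoch number (e.g. 'epoch=3-step=1000.ckpt'), the fixed slicing above breaks. A more defensive sketch, assuming only that the filenames contain an 'epoch=<int>' token:

import re
from os import listdir

def find_best_epoch(ckpt_folder):
    """Return the highest epoch number among 'epoch=<int>*.ckpt' files."""
    pattern = re.compile(r'epoch=(\d+)')  # assumed filename pattern
    epochs = []
    for filename in listdir(ckpt_folder):
        match = pattern.search(filename)
        if match:
            epochs.append(int(match.group(1)))
    if not epochs:
        raise FileNotFoundError('no epoch=<n> checkpoints found in {}'.format(ckpt_folder))
    return max(epochs)

Note that both variants assume ModelCheckpoint only keeps checkpoints worth resuming from (e.g. via save_top_k), so the highest-epoch file in the folder really is the best one.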