When training is restarted and the model is reloaded from a checkpoint, the epoch that the checkpoint had just completed is run again rather than the next one starting. When a checkpoint is saved at the end of an epoch, restoring from it should resume that state and start the next epoch.
https://github.com/PyTorchLightning/pytorch-lightning/blob/edd4a87fb0e976e004a82c7a3491598ab65cafbd/pytorch_lightning/trainer/training_io.py#L389

This seems as simple as replacing the line above with `self.current_epoch = checkpoint['epoch'] + 1`, since the checkpointers save at the end of validation and the main loop runs from the current epoch. We should probably also increase the global step by 1, since the step is incremented after the checkpoint is saved:

https://github.com/PyTorchLightning/pytorch-lightning/blob/edd4a87fb0e976e004a82c7a3491598ab65cafbd/pytorch_lightning/trainer/training_io.py#L388

https://github.com/PyTorchLightning/pytorch-lightning/blob/edd4a87fb0e976e004a82c7a3491598ab65cafbd/pytorch_lightning/trainer/training_loop.py#L411

https://github.com/PyTorchLightning/pytorch-lightning/blob/edd4a87fb0e976e004a82c7a3491598ab65cafbd/pytorch_lightning/trainer/training_loop.py#L430
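For illustration, a minimal sketch of what the proposed change could look like in the restore logic; the method name and surrounding structure are assumptions, and only `self.current_epoch`, `self.global_step`, and the `epoch`/`global_step` checkpoint keys come from the discussion above:

```python
def restore_training_state(self, checkpoint):
    # ... restore optimizers, LR schedulers, etc. ...

    # The checkpoint is written once the epoch (and its validation) has
    # finished, so resume from the *next* epoch instead of repeating the
    # one that was just completed.
    self.current_epoch = checkpoint['epoch'] + 1

    # The global step is incremented after the checkpoint is saved, so
    # bump it by one as well to match an uninterrupted run.
    self.global_step = checkpoint['global_step'] + 1
```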
I'll add a test in while I'm at it.

@williamFalcon @Borda Any other thoughts if I put a PR in with this? The test should presumably go in `test_trainer.py`?
That works when saving at epoch end, but there are many cases of saving during an epoch as well (e.g. for a very large dataset). Both the epoch number and the global step are technically correct whether the save happens mid-epoch or at epoch end, but when resuming, the loop starts at the beginning of the epoch. The best solution would be to make the resume reliably restart precisely where the loop left off. I'm not that familiar with this code, but I guess one would then also need to save the batch_idx.
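To make the batch_idx idea concrete, here is a rough sketch of saving and restoring a mid-epoch position. The function names and the `batch_idx` checkpoint key are hypothetical, not Lightning's API, and fast-forwarding only helps if the dataloader's ordering is reproducible across runs:

```python
def save_mid_epoch_state(checkpoint, current_epoch, global_step, batch_idx):
    """Record where in the epoch training stopped (illustrative keys only)."""
    checkpoint['epoch'] = current_epoch
    checkpoint['global_step'] = global_step
    checkpoint['batch_idx'] = batch_idx  # hypothetical extra key


def resume_epoch(train_dataloader, checkpoint, training_step):
    """Fast-forward to the saved position, then train as usual.

    Only meaningful if the dataloader's order is reproducible (e.g. a seeded
    sampler); otherwise the skipped batches are not the ones originally seen.
    """
    start_batch = checkpoint.get('batch_idx', 0)
    for batch_idx, batch in enumerate(train_dataloader):
        if batch_idx < start_batch:
            continue  # skip batches consumed before the interruption
        training_step(batch, batch_idx)
```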
I can think of at least one way to do it, although it's not ideal:
None or something. Two problems I see:
Currently I've put in a PR (#866) that deals with the off-by-one when loading a checkpoint saved at epoch end. I've also added a warning when loading a mid-epoch checkpoint saying that resuming training is not reliable and that an end-of-epoch checkpoint should be considered instead.

If it's preferable, I can also rerun the previous epoch when we detect a mid-epoch checkpoint, but that technically means you run more epochs than you would expect/report, so I'm not sure it's a good idea.

I'd suggest opening a new issue to discuss how to resume mid-epoch checkpoints, since we have no way of ensuring dataset states are preserved, and closing this issue with the PR I have up.
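For reference, one way the mid-epoch detection behind that warning could look; this is a hedged sketch under the assumption that the global step advances once per batch, and not necessarily the exact check used in #866:

```python
import warnings


def warn_if_mid_epoch(checkpoint, num_training_batches):
    """Warn when a checkpoint appears to have been saved mid-epoch.

    Assumes `global_step` advances once per training batch, so a checkpoint
    written at epoch end has a global_step that is a multiple of the epoch
    length.
    """
    if num_training_batches > 0 and checkpoint['global_step'] % num_training_batches != 0:
        warnings.warn(
            "You're resuming from a checkpoint that was saved mid-epoch. "
            "Training will restart from the beginning of the next epoch, "
            "which is not reliable; consider resuming from an end-of-epoch "
            "checkpoint instead."
        )
```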
Great, a good compromise for now.