Pytorch-lightning: Epoch end checkpoint restarts previous epoch

Created on 15 Feb 2020 · 5 comments · Source: PyTorchLightning/pytorch-lightning

🐛 Bug

When training is restarted and the model is reloaded from a checkpoint, the epoch that the checkpoint had just completed is run again rather than training continuing with the next epoch.

Expected behavior

When a checkpoint is saved at the end of an epoch, restarting from it should restore the trainer state and begin the next epoch.

bug / fix

All 5 comments

https://github.com/PyTorchLightning/pytorch-lightning/blob/edd4a87fb0e976e004a82c7a3491598ab65cafbd/pytorch_lightning/trainer/training_io.py#L389

This seems as simple as replacing the line above with self.current_epoch = checkpoint['epoch'] + 1, since the checkpoint callbacks save at the end of validation and the main loop starts from the current epoch.

We should probably also increase the global step by 1, since the step is incremented after the checkpoint is saved.
https://github.com/PyTorchLightning/pytorch-lightning/blob/edd4a87fb0e976e004a82c7a3491598ab65cafbd/pytorch_lightning/trainer/training_io.py#L388

https://github.com/PyTorchLightning/pytorch-lightning/blob/edd4a87fb0e976e004a82c7a3491598ab65cafbd/pytorch_lightning/trainer/training_loop.py#L411

https://github.com/PyTorchLightning/pytorch-lightning/blob/edd4a87fb0e976e004a82c7a3491598ab65cafbd/pytorch_lightning/trainer/training_loop.py#L430
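
For concreteness, here is a minimal sketch of what that change could look like. The method name and checkpoint keys follow the linked training_io.py lines, but the snippet is an illustration of the idea, not a copy of the Lightning source.

```python
# Minimal sketch of the proposed off-by-one fix when restoring trainer state.
# Names (restore_training_state, 'global_step', 'epoch') follow the linked
# training_io.py, but this is illustrative, not the actual Lightning code.
def restore_training_state(self, checkpoint):
    # The checkpoint callbacks run at the end of validation, i.e. after the
    # epoch has completed, so resume from the next step and the next epoch.
    self.global_step = checkpoint['global_step'] + 1
    self.current_epoch = checkpoint['epoch'] + 1
```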

I'll add a test while I'm at it.

@williamFalcon @Borda Any other thoughts if I put a PR in with this? The test should presumably go in test_trainer.py?

That works when saving at epoch end, but there are many cases of saving during an epoch as well (e.g. for a very large dataset). Both the epoch number and the global step are technically correct whether the save happens mid-epoch or at epoch end, but when resuming, the loop starts at the beginning of the epoch. The best solution is to make the resume reliably restart precisely where in the loop it left off. I'm not that familiar with this code, but I guess one would then also have to save the batch_idx.

I can think of at least one way to do it, although it's not ideal (a rough sketch follows the list below):

  • Use the global step as stored and when resuming, load this into a (hidden?) trainer variable which is checked every training batch.
  • Skip each batch until the batch number is above this saved variable, at which point we can set it to None or something.
  • Run batches as normal
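
To make the idea concrete, here is a rough sketch of that skip-until-caught-up logic. All names here (ResumeSkipper, should_skip, resume_global_step) are hypothetical helpers for illustration, not part of the Trainer API.

```python
# Hypothetical helper sketching the "skip batches until we pass the saved
# global step" idea; not part of pytorch-lightning.
class ResumeSkipper:
    def __init__(self, resume_global_step=None):
        # Global step stored in the checkpoint; None means there is nothing to skip.
        self._skip_until_step = resume_global_step

    def should_skip(self, global_step):
        """Return True while replaying batches the checkpoint had already seen."""
        if self._skip_until_step is None:
            return False
        if global_step <= self._skip_until_step:
            return True
        self._skip_until_step = None  # caught up; train normally from here on
        return False


# Usage in a simplified training loop with a dummy "dataloader":
skipper = ResumeSkipper(resume_global_step=1234)   # value read from the checkpoint
for global_step, batch in enumerate(range(2000)):  # stand-in for the real dataloader
    if skipper.should_skip(global_step):
        continue  # progress-bar/hook updates would also need to be faked here
    # ... run the training step on `batch` ...
```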

Two problems I see:

  • When the dataset is shuffled, there is no guarantee of seeing only new samples after the reload without somehow 'resuming' the dataloaders.
  • Any calls that are usually made during the skipped batches wouldn't happen. For example, the tqdm update calls would need to be faked so the progress bars don't appear to end the epoch early.

I've currently put in a PR (#866) that deals with the off-by-one when loading a checkpoint saved at epoch end. I've also added a warning when loading a mid-epoch checkpoint saying that resuming training from it is not reliable and that an end-of-epoch checkpoint should be considered instead.
If it's preferable, I can also rerun the previous epoch when we detect a mid-epoch checkpoint, but that technically means running for more epochs than you would expect/report, so I'm not sure it's a good idea.
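
For illustration, a hedged sketch of what such a warning could look like. The mid-epoch check and the wording here are my own assumptions, not necessarily what PR #866 implements.

```python
import warnings

# Assumption: an end-of-epoch checkpoint has completed a whole number of
# epochs' worth of batches; anything else is treated as a mid-epoch save.
def warn_if_mid_epoch(checkpoint, num_training_batches):
    if checkpoint['global_step'] % num_training_batches != 0:
        warnings.warn(
            "You're resuming from a mid-epoch checkpoint; training will restart "
            "at the beginning of the epoch, which may not be reliable. Consider "
            "loading a checkpoint saved at the end of an epoch instead."
        )
```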

I'd suggest opening a new issue to discuss how to resume mid-epoch checkpoints, since we have no way of ensuring dataset states are preserved, and closing this issue with the PR I have up.

Great, a good compromise for now
