Pytorch-lightning: Resume Training with trainer.fit(model) from specific values of current_epoch and global_step

Created on 22 Jan 2020  ·  15 Comments  ·  Source: PyTorchLightning/pytorch-lightning

🚀 Feature

Currently, trainer.fit(model) ignores the values of model.global_step and trainer.current_epoch and resets them to 0 at the beginning of each call.

The _Resume Training_ feature will enable you to "continue where you left off" with the training.

Motivation

This is useful for many research use-cases, where something new is being tried and the user wants to "interrogate" their model dynamically and irregularly between subsequent training sessions.

When tracking the training, whether using custom logging or some standard tool like TensorBoard, most people will use current_epoch and / or global_step somewhere in their logging pipeline to "timestamp" their logs / metrics.

Resetting these values in each training session forces the user to manage current_epoch and global_step themselves somewhere in their code, adding boilerplate.

Additionally, it is desirable to have the "state of training" of each model preserved between saving and loading of model checkpoints. It is always good to know the passed_epochs of any particular checkpoint (i.e. "how many times has this model seen the full dataset so far").

Whether the user is in a Jupyter notebook or running Python scripts is irrelevant; in both cases it is useful and desirable to have a mechanism that preserves a checkpoint's training progress from save to load.

Pitch

One could achieve this by manipulating trainer.max_nb_epochs, trainer.min_nb_epochs and, obviously, trainer.current_epoch between subsequent calls to trainer.fit(model).
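
A rough sketch of the desired usage (attribute names as they exist on the Trainer at the time of writing; treating them as plain writable attributes between calls is an assumption):

    import pytorch_lightning as pl

    # model: any LightningModule instance
    trainer = pl.Trainer(max_nb_epochs=10)
    trainer.fit(model)              # trains epochs 0..9

    # ... interrogate the model interactively here ...

    trainer.current_epoch = 10      # keep logging "timestamps" monotonically increasing
    trainer.max_nb_epochs = 20
    trainer.fit(model)              # should resume at epoch 10 instead of resetting to 0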

Minimum change proposal

  • trainer.fit(model) re-uses values of trainer.current_epoch and model.global_step.

Extended change proposal

  • expose model.save_to_checkpoint() to the user, as we already have model.load_from_checkpoint() available
  • implement preservation of "training state" (namely values of current_epoch and global_step) in save_to_checkpoint and load_from_checkpoint

    • ideally in a way such that the user doesn't have to do trainer.current_epoch = model.passed_epochs between calls to model.load_from_checkpoint and trainer.fit(model); see the sketch after this list
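
A hypothetical sketch of the extended proposal (save_to_checkpoint and passed_epochs do not exist today; they only illustrate the proposed symmetry with load_from_checkpoint):

    # hypothetical API, assuming save_to_checkpoint is added as proposed above
    model.save_to_checkpoint('model.ckpt')    # would store weights plus current_epoch / global_step

    # later, possibly in a fresh process or notebook session:
    model = MyLightningModule.load_from_checkpoint('model.ckpt')
    trainer = pl.Trainer(max_nb_epochs=model.passed_epochs + 10)
    trainer.fit(model)                        # would pick up current_epoch / global_step from the checkpoint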

Alternatives

Alternatively, we can choose not to support this feature, but I can imagine other (more specific) use-cases beyond the ones stated above where it would help, so if nothing else, I'm at least interested in how many people would find this useful.

Additional context

In research, training sessions can be long and are ideally interactive. Currently they are treated as atomic operations from the user's perspective: one needs to decide ahead of time everything they will want to examine about their model during training. Adding anything requires adjusting the logging pipeline and restarting the training session from scratch.

Many things are worth examining only sparsely. Having to handle all of this via pre-written logging routines forces examinations to be regular and their intervals to be chosen ahead of time, which is not very research-friendly.

The alternative is to start the training over each time, which is probably what most people do currently. That might suffice as long as you're not regularly logging scalars / images / metrics with training time-stamps (resetting those makes a huge mess in TensorBoard, for example).

Discussion

The extent and realization of this change are open to further considerations.

Stemming from Slack: @williamFalcon @neggert let's discuss here.

Labels: API / design  ·  enhancement  ·  help wanted  ·  won't fix

All 15 comments

This would be very useful for hyperparameter tuning, especially with Hyperband.
Typically we train the model for an epoch, look at its accuracy, and compare that to other runs. If the accuracy isn't good enough, we stop training the model; otherwise we continue for another epoch.
This requires stopping training after each epoch, which I currently can't figure out how to do with pytorch_lightning.

Here's my hack to run the training epoch by epoch. Please let me know if there's a better way to do it.

    import pytorch_lightning as pl

    # model: your LightningModule instance
    # run training one epoch at a time so the model can be inspected in between
    trainer = pl.Trainer(max_epochs=1)
    for i in range(100):
        if i == 0:
            trainer.fit(model)            # first call attaches the trainer and trains epoch 0
        else:
            trainer.current_epoch = i     # restore the epoch counter that fit() would reset
            trainer.max_epochs = i + 1    # allow exactly one more epoch
            trainer.train()               # re-enter the training loop directly

Thanks @tridao, my follow-up hack based on yours looks something like this:

import torch

class PausableModule(MyLightningModule):
    def __init__(self, *args):
        super().__init__(*args)  # nn.Module must be initialised before buffers can be registered
        # track finished epochs in a buffer so the count is saved and loaded with the checkpoint
        self.register_buffer('epochs_passed', torch.zeros([]))

    def on_epoch_end(self):
        self.epochs_passed += 1

    def train_another_n_epochs(self, n):
        # restore the epoch counter that fit() resets, then run n more epochs
        self.trainer.current_epoch = int(self.epochs_passed.data.item())
        self.trainer.max_epochs = self.trainer.current_epoch + n
        self.trainer.train()
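
A typical session with this workaround might then look like this (constructor arguments of MyLightningModule are assumed):

    model = PausableModule()           # pass whatever args MyLightningModule expects
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model)                 # first call attaches the trainer and runs epoch 0

    # ... inspect the model, log whatever you like ...

    model.train_another_n_epochs(5)    # continues from the recorded epoch count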

This is of course just my personal workaround until we have a proper solution, which might look something like this or quite different; I'm not sure what the proper way to do this will be.

@lmartak amazing. can you submit a PR?

@williamFalcon sure, but let's agree on a scope first?

I feel like this is a bit hacky and might introduce some tech debt. For example, we might want to

  • do trainer.current_epoch = int(model.epochs_passed.data.item()) also when trainer is first associated with a model

    • here a sub-question arises: what is the correct way to associate a trainer with a model? Currently I'm doing a somewhat hacky trainer.max_epochs = 0 followed by trainer.fit(model), but I feel there should be a more intuitive / user-friendly way?

  • remove the re-setting (setting to 0) of trainer.current_epoch (and possibly other relevant variables) by trainer.fit() on each call
  • and address other questions raised in the issue.

or should I just go ahead and submit this minimal fix for a start, and we can discuss and finalize the scope in the PR?

no hacks haha. submit a thought-out solution along with tests and doc updates :)

@lmartak @tridao how do you feel about making a PR?

@Borda I probably won't find time before the conference deadline I'm targeting. If it's still open when I find time, I'll give it a shot (and write here when I start).

@lmartak how is it going here?

This is rather strange. I am surprised you can't stop training your model (in Jupyter), evaluate it, and then keep training it. You can do this with plain PyTorch easily.

Would be super useful for fit to behave this way instead of restarting from scratch.

???
you can in a notebook! (in fact the demos do this); a sketch of the cells follows the steps below.

  1. call .fit on the notebook
  2. interrupt training
  3. call .test (on a different cell)
  4. call .fit again (on a different cell)
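
Roughly, the cells might look like this (a sketch of the flow above, not taken from the demos):

    # cell 1
    trainer = pl.Trainer()
    trainer.fit(model)        # interrupt the kernel whenever you want to pause

    # cell 2
    trainer.test(model)

    # cell 3
    trainer.fit(model)        # call .fit again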

Oh, for some reason if I interrupt my .fit and then run it again, it resets my loss.

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@drozzy hey any chance you are interested in doing this?

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!
