I want to resume training not only by number of epochs (num_epochs) but also by number of iterations. How can I do that?
The approach from issue #470 can only resume training by number of epochs.
@jyzhang-bjtu the following should work:
```python
from ignite.engine import Events

# Restore the saved states (PyTorch uses load_state_dict)
model.load_state_dict(saved_model_dict)
optimizer.load_state_dict(saved_optimizer_dict)
lr_scheduler.load_state_dict(saved_lr_scheduler_dict)

# Redefine the trainer as previously
trainer = ...
data = ...

# Set the initial epoch
resume_epoch = 2  # zero-based

@trainer.on(Events.STARTED)
def resume_training(engine):
    engine.state.iteration = resume_epoch * len(engine.state.dataloader)
    engine.state.epoch = resume_epoch

trainer.run(data, max_epochs=max_epochs)
```
HTH
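To see why setting the counters in a STARTED handler skips the finished epochs, here is a toy model of the engine loop. It assumes, as described later in this thread, that the loop is `while state.epoch < max_epochs` with `state.epoch += 1` at the start of each pass. `MiniEngine`, `MiniState` and `on_started` are hypothetical names for illustration, not Ignite classes.

```python
from typing import Callable, List, Sequence


class MiniState:
    """Hypothetical stand-in for engine.state (not Ignite's State class)."""

    def __init__(self) -> None:
        self.epoch = 0
        self.iteration = 0
        self.dataloader: Sequence = []


class MiniEngine:
    """Simplified engine loop: state.epoch is incremented at the START of each pass."""

    def __init__(self, process_fn: Callable[["MiniEngine", object], None]) -> None:
        self.process_fn = process_fn
        self.state = MiniState()
        self._started_handlers: List[Callable[["MiniEngine"], None]] = []

    def on_started(self, fn: Callable[["MiniEngine"], None]) -> Callable[["MiniEngine"], None]:
        """Register a handler that plays the role of Events.STARTED."""
        self._started_handlers.append(fn)
        return fn

    def run(self, data: Sequence, max_epochs: int) -> MiniState:
        self.state.dataloader = data
        for fn in self._started_handlers:
            fn(self)
        while self.state.epoch < max_epochs:
            self.state.epoch += 1
            for batch in self.state.dataloader:
                self.state.iteration += 1
                self.process_fn(self, batch)
        return self.state


# Usage: the same resume handler as in the snippet above, on the toy engine.
seen = []
trainer = MiniEngine(lambda eng, batch: seen.append((eng.state.epoch, batch)))
resume_epoch = 2  # two epochs already completed


@trainer.on_started
def resume_training(engine: MiniEngine) -> None:
    engine.state.iteration = resume_epoch * len(engine.state.dataloader)
    engine.state.epoch = resume_epoch


state = trainer.run([10, 20, 30], max_epochs=4)
```

Only epochs 3 and 4 (1-based) run, and `state.iteration` ends at 12 (4 epochs × 3 batches), so the counters line up as if training had never stopped. This resumes only at an epoch boundary; resuming mid-epoch is the harder problem raised later in the thread.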
@vfdev-5 maybe we can add this to the FAQ
@anmolsjoshi Good idea! If you can send a PR, it would be perfect :)
Thanks for all the replies. However, I found some problems that need fixing:
1. If we restart at a given iteration, the dataloader cannot restart at the same position.
2. According to the engine's source, the restarted `engine.state.epoch` should be `resume_epoch - 1`.
For problem 2, I suggest moving `self.state.epoch += 1` to the end of the `while` block, and doing the same for `self.state.iteration += 1`.
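The effect of the proposed reordering can be checked with a toy model of the loop. This is a simplified assumption of the `while state.epoch < max_epochs` structure, not Ignite's source, and `epoch_labels` is a hypothetical helper. It reports the value of `state.epoch` that handlers would see during each remaining pass after a resume.

```python
from typing import List


def epoch_labels(resume_epoch: int, max_epochs: int, increment_at_end: bool) -> List[int]:
    """Value of state.epoch seen by handlers during each remaining pass.

    Simplified model of the engine's `while state.epoch < max_epochs` loop
    (an assumption for illustration, not Ignite's source).
    """
    epoch = resume_epoch  # what the STARTED handler sets
    labels = []
    while epoch < max_epochs:
        if not increment_at_end:
            epoch += 1  # current ordering: increment before the pass
        labels.append(epoch)  # handlers and the process function run here
        if increment_at_end:
            epoch += 1  # proposed ordering: increment after the pass
    return labels


print(epoch_labels(2, 4, increment_at_end=False))  # [3, 4]
print(epoch_labels(2, 4, increment_at_end=True))   # [2, 3]
```

Both orderings run the same number of passes. What changes is whether `state.epoch` reads as a 1-based count (current ordering) or a 0-based index (proposed ordering) while an epoch is running, which is where the off-by-one in problem 2 comes from.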
> If we restart at a given iteration, the dataloader cannot restart at the same position.
@jyzhang-bjtu Resuming from an iteration is not that simple. We have a WIP PR on that: #182.
Mainly, the question is how to efficiently "fast forward" an iterator.
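The simplest (but not efficient) way to fast-forward is to advance the iterator and discard the first `n` items, using the "consume" recipe from the `itertools` documentation. A sketch, where `fast_forward` is a hypothetical helper:

```python
import itertools
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def fast_forward(iterable: Iterable[T], n: int) -> Iterator[T]:
    """Return an iterator over `iterable` with its first `n` items skipped.

    The skipped items are still produced and then discarded, so a data
    loader still pays the cost of loading them.
    """
    it = iter(iterable)
    # "consume" recipe from the itertools docs: advance `it` by n steps.
    next(itertools.islice(it, n, n), None)
    return it


# Resume mid-epoch: `iteration` completed iterations, `epoch_length` batches per epoch.
iteration, epoch_length = 7, 5
batches_done_this_epoch = iteration % epoch_length  # 2
remaining = list(fast_forward(range(epoch_length), batches_done_this_epoch))
```

For a PyTorch `DataLoader` this still loads every skipped batch, which is exactly the inefficiency in question. Also, if the loader shuffles, the skipped batches only match the interrupted run if the sampler's random state is restored as well.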
> According to the source of engine, the restarted engine.state.epoch should be resume_epoch - 1.
You are free to modify the resuming handler as you wish:
```python
@trainer.on(Events.STARTED)
def resume_training(engine):
    engine.state.iteration = (resume_epoch - 1) * len(engine.state.dataloader)
    engine.state.epoch = resume_epoch - 1
```
Thanks a lot!
Here is an idea from my experience: we can use `engine.state.iteration` to tell the dataloader where to resume. If the dataloader had a `seek` method, it could resume from the given position.
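A sketch of that idea, assuming a hypothetical loader class with a `seek(batch_index)` method. To my knowledge PyTorch's `DataLoader` has no such method, so `SeekableLoader` and `seek` are illustrative names only. Because this loader slices a sequence instead of consuming an iterator, seeking skips batches without loading them; the position within the epoch is `iteration % len(loader)`.

```python
from typing import Iterator, List, Sequence


class SeekableLoader:
    """Hypothetical batch loader with a `seek` method (not a PyTorch/Ignite API)."""

    def __init__(self, data: Sequence, batch_size: int) -> None:
        self.data = data
        self.batch_size = batch_size
        self._start = 0  # index of the first batch the next pass yields

    def __len__(self) -> int:
        # Number of batches per epoch (the last batch may be smaller).
        return (len(self.data) + self.batch_size - 1) // self.batch_size

    def seek(self, batch_index: int) -> None:
        if not 0 <= batch_index <= len(self):
            raise ValueError(f"batch_index out of range: {batch_index}")
        self._start = batch_index

    def __iter__(self) -> Iterator[List]:
        # Read and reset the start position eagerly: a seek affects one pass only.
        start, self._start = self._start, 0
        return self._batches(start)

    def _batches(self, start: int) -> Iterator[List]:
        for b in range(start, len(self)):
            lo = b * self.batch_size
            yield list(self.data[lo:lo + self.batch_size])


loader = SeekableLoader(list(range(10)), batch_size=3)  # 4 batches per epoch
iteration = 9  # e.g. engine.state.iteration restored from a checkpoint
loader.seek(iteration % len(loader))  # 9 % 4 == 1: resume at the second batch
resumed = list(loader)
next_epoch = list(loader)
```

Here `resumed` holds only the last three batches of the interrupted epoch, while `next_epoch` holds all four, because the seek is consumed by the first pass.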
I'm closing the issue as answered. For more details on how to resume training, please see here