I have been using Ignite for quite some time now, and recently I came across a change in ModelCheckpoint. I'm aware of the backward-compatibility-breaking changes in v0.3.0.
Still, I want to know whether this is the intended behavior.
Previously, we could do something like this to save both the model parameters and the optimizer state:
```python
import torch
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint

# Dummy trainer: Ignite calls the process function with (engine, batch)
trainer = Engine(lambda engine, batch: None)
handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=None, create_dir=True)
model = torch.nn.Linear(3, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# Save the model weights and the optimizer state every 2 epochs
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {"weights": model, "optimizer": optimizer})
```
This saves two files: the optimizer state in /tmp/models/myprefix_optimizer_{iteration}.pth and the weights in /tmp/models/myprefix_weights_{iteration}.pth.
After the breaking change, this no longer works, I think because of these lines:
https://github.com/pytorch/ignite/blob/39c4a8e2ef9a2a6c3a6109abc395ad39316349d8/ignite/handlers/checkpoint.py#L212-L224
Now a single file, /tmp/models/myprefix_checkpoint_{iteration}.pth, is written, and it contains both the weights and the optimizer state_dict.
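To make the difference between the two layouts concrete, here is a minimal pure-Python sketch. It does not call Ignite: `checkpoint_layout` and `StatefulStub` are hypothetical names, and it assumes, as described in this thread, that the old behavior wrote one `{prefix}_{key}_{iteration}.pth` file per object, while v0.3.0 writes a single `{prefix}_checkpoint_{iteration}.pth` file holding every state_dict under its key.

```python
from typing import Any, Dict, Mapping


class StatefulStub:
    """Hypothetical stand-in for a model or optimizer that exposes state_dict()."""

    def __init__(self, state: Dict[str, Any]) -> None:
        self._state = state

    def state_dict(self) -> Dict[str, Any]:
        return dict(self._state)


def checkpoint_layout(
    to_save: Mapping[str, Any], dirname: str, prefix: str, iteration: int, single_file: bool
) -> Dict[str, Any]:
    """Map each file path that would be written to the content stored in it (sketch only)."""
    if single_file:
        # v0.3.0 layout as described in this thread: one file, one entry per key.
        path = f"{dirname}/{prefix}_checkpoint_{iteration}.pth"
        return {path: {name: obj.state_dict() for name, obj in to_save.items()}}
    # Earlier layout as described in this thread: one file per object.
    return {
        f"{dirname}/{prefix}_{name}_{iteration}.pth": obj.state_dict()
        for name, obj in to_save.items()
    }


to_save = {"weights": StatefulStub({"w": 1.0}), "optimizer": StatefulStub({"lr": 1e-3})}
old = checkpoint_layout(to_save, "/tmp/models", "myprefix", 2, single_file=False)
new = checkpoint_layout(to_save, "/tmp/models", "myprefix", 2, single_file=True)
print(sorted(old))  # ['/tmp/models/myprefix_optimizer_2.pth', '/tmp/models/myprefix_weights_2.pth']
print(sorted(new))  # ['/tmp/models/myprefix_checkpoint_2.pth']
```

In the single-file case, every object can be restored from one `torch.load`, at the cost of no longer having a separate file per object.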
To get the old behavior, I have to "duplicate" the last line like this:
```python
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {"weights": model})
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {"optimizer": optimizer})
```
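With more than two objects, the duplicated registrations can be written as a loop. The sketch below is a hypothetical helper (`add_handler_per_object` is not part of Ignite); it only assumes that the engine exposes `add_event_handler(event, handler, *args)`, and it uses a recording stub (`RecordingEngine`, also hypothetical) so it runs without Ignite installed.

```python
from typing import Any, List, Mapping, Tuple


def add_handler_per_object(engine: Any, event: Any, handler: Any, to_save: Mapping[str, Any]) -> None:
    """Register `handler` once per object so that each object is saved to its own file."""
    for name, obj in to_save.items():
        # One single-entry dict per registration, like {"weights": model} and {"optimizer": optimizer}.
        engine.add_event_handler(event, handler, {name: obj})


class RecordingEngine:
    """Hypothetical stand-in for an Ignite Engine that only records registrations."""

    def __init__(self) -> None:
        self.calls: List[Tuple[Any, Any, tuple]] = []

    def add_event_handler(self, event: Any, handler: Any, *args: Any) -> None:
        self.calls.append((event, handler, args))


engine = RecordingEngine()
add_handler_per_object(engine, "EPOCH_COMPLETED", "handler", {"weights": "model", "optimizer": "optimizer"})
for call in engine.calls:
    print(call)
```

With Ignite itself, the equivalent call would presumably be `add_handler_per_object(trainer, Events.EPOCH_COMPLETED(every=2), handler, {"weights": model, "optimizer": optimizer})`, which performs the same two registrations as the snippet above.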
Is this the intended new behavior for saving state_dicts?
@czotti Since v0.3.0, ModelCheckpoint and Checkpoint store all objects to save in a single file, whereas previous versions used multiple files (see the corresponding test).
The idea is to simplify loading all objects when resuming training, with a single torch.load:
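```python
import torch
from ignite.handlers import Checkpoint

to_load = {"model": model, "optimizer": optimizer}  # ... plus any other objects that were saved
checkpoint = torch.load(checkpoint_fp)  # checkpoint_fp: path to the saved .pth file
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
```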
Yes, to save to multiple files, you need to add the handler twice, as you suggest.
What do you think?
HTH
Oh, I didn't know you could load the checkpoint like this; that's fine with me. Maybe add this code snippet to the documentation to show how to load checkpoints, because at first I was doing it by hand like this:
```python
import torch

# Restore each state_dict by hand from the single checkpoint file
checkpoint = torch.load(checkpoint_fp)
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
```
or with a for loop, like load_objects does, when I have more state_dicts to load.
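A minimal sketch of that hand-written loop, generalized to any number of objects. `load_state_dicts` and `LoadableStub` are hypothetical names; this is not the source of Ignite's `Checkpoint.load_objects`, which may handle additional cases, but it shows the idea of matching each key of `to_load` to the checkpoint entry with the same name.

```python
from typing import Any, Mapping


def load_state_dicts(to_load: Mapping[str, Any], checkpoint: Mapping[str, Any]) -> None:
    """Load checkpoint[name] into every object of to_load via its load_state_dict()."""
    for name, obj in to_load.items():
        if name not in checkpoint:
            raise KeyError(f"checkpoint has no entry for {name!r}")
        obj.load_state_dict(checkpoint[name])


class LoadableStub:
    """Hypothetical stand-in for a model or optimizer that records what it loads."""

    def __init__(self) -> None:
        self.loaded = None

    def load_state_dict(self, state: Mapping[str, Any]) -> None:
        self.loaded = state


model, optimizer = LoadableStub(), LoadableStub()
# With PyTorch, checkpoint would come from torch.load(checkpoint_fp) instead.
checkpoint = {"model": {"w": 1.0}, "optimizer": {"lr": 1e-3}}
load_state_dicts({"model": model, "optimizer": optimizer}, checkpoint)
print(model.loaded, optimizer.loaded)  # {'w': 1.0} {'lr': 0.001}
```

Raising `KeyError` on a missing entry is a design choice of this sketch: it fails loudly instead of silently leaving an object at its initial state.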
@czotti Good idea, thanks! It seems that Checkpoint.load_objects is not picked up by Sphinx when rendering the docs.
PS: If you would like to contribute to Ignite by improving our docs, do not hesitate to send a PR; it would be very helpful for us!