Pytorch-lightning: Clarify the model checkpoint arguments

Created on 24 Oct 2020 · 12 comments · Source: PyTorchLightning/pytorch-lightning

🐛 Proposals

This is not so much a bug report as an RFC to clarify the ModelCheckpoint callback arguments:

  • save_last: to me, this means that whenever we save a checkpoint, we also save a checkpoint with the filename "last.ckpt". This provides a pre-determined checkpoint name, which is very helpful for resuming from failures. Importantly, it should not determine when checkpoints are saved. Currently it's easy to misread this parameter as "save the checkpoint after the last epoch," which I think should be split out as a separate argument. This distinction would also clarify the typing and validation: there's no need for it to be an Optional[bool]. Either we save a checkpoint as "last.ckpt" or not, so it can be a regular bool.
  • There's an inefficiency right now where we generate the checkpoint dict twice if save_last=True. For techniques like ZeRO that deal with sharded optimizer states, each checkpoint dict creation triggers communications across all ranks. Instead, we should gather the checkpoint dict once, and then save to different file paths accordingly (cc @SeanNaren, @blefaudeux)

  • save_top_k: since monitor is None by default, this should force save_top_k to be -1. The counterargument is that this can raise storage concerns, but I think users can easily correct this by configuring save_top_k and monitor.

  • period: we should rename this to every_n_epochs. This opens up extensions for checkpointing every_n_steps during training and checkpointing after a specified time interval. With those extensions in mind, period is ambiguous. Another request here is to change the default filename from "{epoch}" to "{epoch}-{global_step}" to better support mid-epoch checkpointing.
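The second bullet's idea (build the checkpoint dict once, then write it to every target path) can be sketched without Lightning. This is a minimal sketch under stated assumptions: `pickle` stands in for `torch.save`, and `save_checkpoint` with its `build_state` argument are hypothetical names, not Lightning's API. `build_state` runs exactly once, so an expensive gather of sharded optimizer state would also happen once, however many files are written.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Callable, Dict, List


def save_checkpoint(build_state: Callable[[], Dict], dirpath: Path,
                    filename: str, save_last: bool) -> List[Path]:
    """Serialize the checkpoint once, then write the same bytes to every target."""
    # Build (in a sharded setup: gather across ranks) the state exactly once.
    state = build_state()
    payload = pickle.dumps(state)  # stand-in for torch.save into a buffer

    targets = [dirpath / f"{filename}.ckpt"]
    if save_last:
        # save_last only adds a second, predictable name; it does not
        # decide *when* a checkpoint is saved.
        targets.append(dirpath / "last.ckpt")
    for path in targets:
        path.write_bytes(payload)
    return targets


# Usage: count how often the (expensive) state is built.
calls = []


def build_state() -> Dict:
    calls.append(1)
    return {"epoch": 3, "global_step": 1200}


with tempfile.TemporaryDirectory() as tmp:
    written = save_checkpoint(build_state, Path(tmp), "epoch=3", save_last=True)
    print([p.name for p in written], len(calls))  # ['epoch=3.ckpt', 'last.ckpt'] 1
```

Writing bytes twice still costs I/O, but the communication-heavy part (building the dict) is no longer duplicated.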

cc @awaelchli @carmocca

Labels: Checkpoint, enhancement, help wanted


All 12 comments

save_last: I agree; I believe this was the original intention when this feature was added. Definitely useful to have.
period: Regarding naming, we have a PR here, #3807, which needs to be finished but implements exactly this :)

  • save_last: What about having a symlink_to_last: bool = False argument so there is always a symbolic link to the last saved epoch? It'd also avoid the inefficiency you talk about in your second point.
  • period: I think we should keep the name period and have a mechanism to either track epochs or steps.

Also, if we allow tracking epochs or steps, save_last should not be exclusive to epochs.
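One way to keep the name `period` while tracking either epochs or steps is to pair it with a unit. The sketch below is hypothetical (`SaveSchedule`, `unit`, and `should_save` are not Lightning names) and assumes the caller invokes `should_save` after every training step and at every epoch end, passing 0-based counters.

```python
from dataclasses import dataclass


@dataclass
class SaveSchedule:
    """Hypothetical trigger: save every `period` epochs or every `period` steps."""
    period: int = 1
    unit: str = "epochs"  # "epochs" or "steps"

    def __post_init__(self) -> None:
        if self.period < 1:
            raise ValueError("period must be >= 1")
        if self.unit not in ("epochs", "steps"):
            raise ValueError("unit must be 'epochs' or 'steps'")

    def should_save(self, epoch: int, global_step: int, on_epoch_end: bool) -> bool:
        """on_epoch_end is True at the end of an epoch, False after a training step."""
        # epoch and global_step are 0-based counters, hence the +1.
        if self.unit == "epochs":
            return on_epoch_end and (epoch + 1) % self.period == 0
        return not on_epoch_end and (global_step + 1) % self.period == 0


# One schedule per callback covers "every 10,000 steps and at each epoch".
every_epoch = SaveSchedule(period=1, unit="epochs")
every_10k_steps = SaveSchedule(period=10_000, unit="steps")
print(every_10k_steps.should_save(epoch=0, global_step=9_999, on_epoch_end=False))  # True
```

Checking `on_epoch_end` in both branches keeps a step-based schedule from firing a second time at the epoch boundary for the same `global_step`.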

save_last: What about having a symlink_to_last: bool = False argument so there is always a symbolic link to the last saved epoch? It'd also avoid the inefficiency you talk about in your second point.

I didn't see symlink support in the fsspec API, so this could be challenging without it. But yeah, that's the general idea. Maybe we call it save_as_last? save_most_recent? Open to name ideas here haha

period: I think we should keep the name period and have a mechanism to either track epochs or steps.

Users might want to do both: e.g. save a checkpoint every 10,000 steps and at each epoch

Users might want to do both: e.g. save a checkpoint every 10,000 steps and at each epoch

Yes, but I would support that by allowing having multiple ModelCheckpoint callbacks.

ModelCheckpoint has become quite complex lately, so we should evaluate splitting it at some point in the future.
Any further changes we make should line up with a thought-out future API.
One possibility is separating the basic saving functionality from tracking the TopK of anything. It could be something similar to:

```python
from pathlib import Path
from typing import Optional, Union


class ModelCheckpoint:
    verbose: bool = False
    save_weights_only: bool = False
    period: int = 1
    dirpath: Optional[Union[str, Path]] = None
    filename: Optional[str] = None
    symlink_last: bool = False  # what you propose to be save_last
    # This is missing a mechanism to track either epochs or steps


class TopKModelCheckpoint(ModelCheckpoint):
    # Notice these are not optional anymore
    monitor: str
    save_top_k: int = 1
    mode: str = "auto"
    save_on_end: bool = False  # what we currently call save_last
```

Just an idea. Open to modifications of course :smile:

Edits:

  • (22-12-2020): Remove prefix
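The proposed split keeps "which checkpoints are worth keeping" out of the basic saver. As a rough sketch of the TopK half (the class and method names are hypothetical, and only `min`/`max` are handled, not `"auto"`), a bounded heap is enough: the saver writes a file when `update` says so and deletes whatever path it returns.

```python
import heapq
from typing import List, Optional, Tuple


class TopKTracker:
    """Hypothetical helper: remember the k best checkpoint paths for one metric."""

    def __init__(self, k: int, mode: str = "min") -> None:
        if k < 1:
            raise ValueError("k must be >= 1")
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.k = k
        self.mode = mode
        # Min-heap of (key, path); the root is always the *worst* kept entry.
        self._heap: List[Tuple[float, str]] = []

    def _key(self, score: float) -> float:
        # Larger key means better, so negate scores when lower is better.
        return -score if self.mode == "min" else score

    def update(self, score: float, path: str) -> Tuple[bool, Optional[str]]:
        """Return (should_save, path_to_delete) for a newly computed score."""
        item = (self._key(score), path)
        if len(self._heap) < self.k:
            heapq.heappush(self._heap, item)
            return True, None
        if item[0] > self._heap[0][0]:
            evicted = heapq.heapreplace(self._heap, item)  # pop worst, push new
            return True, evicted[1]
        return False, None

    def best_paths(self) -> List[str]:
        """Kept paths, best first."""
        return [path for _, path in sorted(self._heap, reverse=True)]


tracker = TopKTracker(k=2, mode="min")  # e.g. monitoring a validation loss
for epoch, val_loss in enumerate([0.9, 0.5, 0.7, 0.8]):
    save, delete = tracker.update(val_loss, f"epoch={epoch}.ckpt")
    print(epoch, save, delete)
print(tracker.best_paths())  # ['epoch=1.ckpt', 'epoch=2.ckpt']
```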

I also agree the current ModelCheckpoint is confusing. I really like the proposal in the first message that would clarify a lot of things (epoch/steps, save_last, etc).

I don't think symlinks should be used, as you never really know which filesystem you are saving the checkpoint to. I would let the user deal with storage usage.

Ideally, a single ModelCheckpoint class would be best (IMO) instead of two but I guess it's a matter of taste at this point.

I have a custom checkpoint callback that inherits from ModelCheckpoint to support every_n_steps: it saves everything (save_top_k=-1 and monitor=None) and skips saving when it is not an n-th step, i.e. when (global_step + 1) % every_n_steps != 0.

At the same time, I also want to store the last checkpoint (save_last=True) so that we can resume training after a crash.

However, this exception prevents me from doing so:

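```python
# Excerpt from ModelCheckpoint's argument validation
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if self.save_last:
    raise MisconfigurationException(
        'ModelCheckpoint(save_last=True, monitor=None) is not a valid configuration.'
        ' You can save the last checkpoint with ModelCheckpoint(save_top_k=None, monitor=None)'
    )
```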

I'm wondering if we could relax that constraint by giving a warning instead?

I'm wondering if we could relax that constraint by giving a warning instead?

Yes! This was already discussed in Slack. Feel free to open a PR.

Ananth S Oct 15th at 6:23 AM
Why does this ModelCheckpoint instantiation raise a misconfiguration error?
Isn't saving the last checkpoint independent of the monitor/topK?

Adrian  23 days ago
@carmocca do you remember?

carmocca  22 days ago
Because monitor=None already saves the last ckpt (but not named last.ckpt) so why would you want to save last again (but named last.ckpt)

carmocca  22 days ago
Maybe we should have used the phrase "is a redundant configuration" instead of "is not a valid configuration".

carmocca  22 days ago
save_last only makes sense when something is being monitored

Ananth S  22 days ago
I see. We are using TorchElastic for long training runs. Using save_last gives us a checkpoint named last.ckpt, which is a deterministic path to find and resume training from.

Ananth S  22 days ago
What we really want is "save the most recent checkpoint as last.ckpt".

Adrian  22 days ago
That is the original meaning of save_last: to have just that name. It should be a copy of the same file, epoch=...ckpt.

Adrian  22 days ago
I definitely agree this is useful to have

carmocca  21 days ago
Feel free to send a PR changing "raise MisconfigurationException" to a warning, and also changing "is not a valid" to "is a redundant".
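A minimal sketch of the agreed change, assuming only the two points above (warn instead of raise, and say "redundant" instead of "not valid"). `validate_save_last` is a hypothetical stand-alone function; Lightning's actual check may be structured differently.

```python
import warnings
from typing import Optional


def validate_save_last(save_last: bool, monitor: Optional[str]) -> None:
    """Hypothetical check: warn about the redundant combination instead of raising."""
    if save_last and monitor is None:
        # With monitor=None every save is already the most recent checkpoint,
        # so last.ckpt is a duplicate; its fixed name can still be useful.
        warnings.warn(
            "ModelCheckpoint(save_last=True, monitor=None) is a redundant configuration."
            " last.ckpt will duplicate the most recently saved checkpoint.",
            UserWarning,
        )


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    validate_save_last(save_last=True, monitor=None)        # warns, does not raise
    validate_save_last(save_last=True, monitor="val_loss")  # no warning
print(len(caught))  # 1
```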

save_last only makes sense when something is being monitored

why is that??
IMO, save_last should have nothing to do with the monitor or save_top_k. Why do we have this warning?

with

ModelCheckpoint(save_last=True, monitor=None, save_top_k=None|0|-1)

it raises this warning

but I think that with

ModelCheckpoint(save_last=True, monitor=anything, save_top_k=anything)

it should not raise any warning or exception w.r.t. save_last.

why is that??

From my last comment:

Because monitor=None already saves the last ckpt (but not named last.ckpt) so why would you want to save last again (but named last.ckpt)

Because monitor=None already saves the last ckpt (but not named last.ckpt)

Oh, OK. Then I suggest reverting it to an exception, with the additional condition self.save_last and self.save_top_k == -1, since saving the same checkpoint twice doesn't make sense. If one needs to access the last checkpoint, the path is available in best_k_models with the epoch key. Also, save_last should be independent of the monitor, I think.

The original feature of save_last was just to have a copy of epoch=x.ckpt saved as last.ckpt. This would happen regardless of other settings. See the original PR #1908 and the original feature request #1658.
The benefit of having this file is being able to load/resume the checkpoint programmatically through last.ckpt without needing to know what the last epoch was. This is also what @ananthsub describes in the original message.
I hope we can keep this feature, as I find it very useful.
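To illustrate the benefit described above, here is a sketch of resume-path discovery. It assumes checkpoint names containing `epoch=N` (a hypothetical naming pattern) and a `find_resume_checkpoint` helper that is not part of Lightning. With `last.ckpt` present the lookup is a single fixed path; without it, the code has to scan the directory and compare epochs numerically.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional

# Assumes epoch-based file names such as "epoch=9.ckpt" (hypothetical pattern).
EPOCH_RE = re.compile(r"epoch=(\d+)")


def find_resume_checkpoint(dirpath: Path) -> Optional[Path]:
    """Return last.ckpt if it exists, else the epoch=N checkpoint with the largest N."""
    last = dirpath / "last.ckpt"
    if last.is_file():
        # The fixed name is the point of save_last: no directory scan needed.
        return last
    # Fallback without last.ckpt: compare epochs numerically, not as strings,
    # so that epoch=10 sorts after epoch=9.
    candidates = []
    for path in dirpath.glob("*.ckpt"):
        match = EPOCH_RE.search(path.name)
        if match:
            candidates.append((int(match.group(1)), path))
    if not candidates:
        return None
    return max(candidates, key=lambda item: item[0])[1]


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp)
    for name in ("epoch=9.ckpt", "epoch=10.ckpt"):
        (ckpt_dir / name).touch()
    print(find_resume_checkpoint(ckpt_dir).name)  # epoch=10.ckpt
    (ckpt_dir / "last.ckpt").touch()
    print(find_resume_checkpoint(ckpt_dir).name)  # last.ckpt
```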

I believe we all get confused because save_last can mean two different things.

  1. A copy of the last checkpoint saved (regardless of anything else). This would ideally be a symlink.
  2. A checkpoint of the model just before training is over.

The current save_last corresponds to (2). Note that (1) and (2) are the same thing in some circumstances.

In https://github.com/PyTorchLightning/pytorch-lightning/issues/4335#issuecomment-716051864 I propose having (1) and (2) as symlink_to_last and save_on_end respectively for clarity.
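A rough sketch of option (1) for local paths only. Because the thread points out that symlinks may not be available through fsspec or on every filesystem, this version falls back to a plain copy. `update_last` is a hypothetical helper built on `os.symlink`, `shutil.copyfile`, and `os.replace`, and it assumes the checkpoint and `last.ckpt` live in the same directory.

```python
import os
import shutil
import tempfile
from pathlib import Path


def update_last(src: Path, last: Path) -> str:
    """Make `last` refer to `src`: a symlink if possible, else a plain copy.

    Assumes both paths are in the same local directory.
    Returns "symlink" or "copy" so the caller can log which one was used.
    """
    tmp = last.with_name(last.name + ".tmp")
    if tmp.is_symlink() or tmp.exists():
        tmp.unlink()
    try:
        # Relative target, so the link stays valid if the directory is moved.
        os.symlink(src.name, tmp)
        kind = "symlink"
    except (OSError, NotImplementedError):
        # e.g. Windows without symlink privileges, or filesystems without symlinks
        shutil.copyfile(src, tmp)
        kind = "copy"
    # Rename over the old last.ckpt; on POSIX this replacement is atomic.
    os.replace(tmp, last)
    return kind


with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_dir = Path(tmp_dir)
    for epoch in range(2):
        ckpt = ckpt_dir / f"epoch={epoch}.ckpt"
        ckpt.write_bytes(f"weights {epoch}".encode())
        kind = update_last(ckpt, ckpt_dir / "last.ckpt")
    print(kind, (ckpt_dir / "last.ckpt").read_bytes())  # b'weights 1' either way
```

Writing to a temporary name and then renaming means a reader never finds `last.ckpt` missing mid-update.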

Also, save_last should be independent of the monitor, I think.

It should if we are talking about (1). If we are talking about (2), it is redundant because if monitor is None we will save every validation run.
