Pytorch-lightning: How to use multiple metric monitors in ModelCheckpoint callback?

Created on 11 Aug 2020  ·  7 Comments  ·  Source: PyTorchLightning/pytorch-lightning

❓ Questions and Help

What is your question?

How can I use multiple metric monitors in the ModelCheckpoint? Put another way, how can I use multiple ModelCheckpoint callbacks? It seems that the Trainer only accepts a single ModelCheckpoint in the checkpoint_callback argument.

Code

site-packages/pytorch_lightning/trainer/callback_config.py", line 46, in configure_checkpoint_callback
    checkpoint_callback.save_function = self.save_checkpoint
AttributeError: 'list' object has no attribute 'save_function'

What's your environment?

  • OS: Ubuntu 16.04
  • Packaging: pip
  • Version: pytorch-lightning==0.9.0rc12
Labels: API / design · Important · discussion · question

All 7 comments

Hi! Thanks for your contribution, and great first issue!

We currently don't support multiple ModelCheckpoint callbacks.
For monitoring multiple metrics with the same callback, I think you have to use the Results object:
https://pytorch-lightning.readthedocs.io/en/latest/results.html#checkpoint-early-stop

Do you plan to support it? It would be nice to be able to do the following:

import pytorch_lightning as pl

filepath = "checkpoints/"  # example save location; substitute your own path

# Save the top 3 models wrt precision
on_best_precision = pl.callbacks.ModelCheckpoint(
    filepath=filepath + "{epoch}-{precision}",
    monitor="precision",
    save_top_k=3,
    mode="max",
)
# Save the top 3 models wrt recall
on_best_recall = pl.callbacks.ModelCheckpoint(
    filepath=filepath + "{epoch}-{recall}",
    monitor="recall",
    save_top_k=3,
    mode="max",
)
# Save the model every 5 epochs
every_five_epochs = pl.callbacks.ModelCheckpoint(
    period=5,
    save_top_k=-1,
    save_last=True,
)
trainer = pl.Trainer(
    checkpoint_callback=[on_best_precision, on_best_recall, every_five_epochs],
)

and something similar for the early_stop_callback flag.

ping @awaelchli

I think yes, one day we will manage that. However, there are soooo many edge cases we need to consider, and we haven't even figured out all of them for the case of a single checkpoint callback, so we need to approach this with care. It should be done by someone who is very confident in how model checkpointing works in PL. That's solely my opinion of course; I'm not saying you or anybody else can't do it. It's just generally a hard task in my opinion, because it connects to many parts of the trainer: validation, early stopping, resuming from checkpoints, formatting names for checkpoint files, top-k management, saving the last checkpoint, persisting state... the list goes on.

cc @PyTorchLightning/core-contributors

Hi, this would be a great addition to Lightning! Do you know roughly when you're planning to include this feature?

@bartmch You can pass in multiple callbacks to the Trainer, just make sure you have the filename set properly in the checkpoint callback so that the checkpoints don't collide.
