Pytorch-lightning: Trainer should run the test loop with the best weights when ModelCheckpoint is used

Created on 2 Jun 2020  ·  9 Comments  ·  Source: PyTorchLightning/pytorch-lightning

🚀 Feature

Motivation


I noticed that even when ModelCheckpoint is used, Trainer by default runs the test loop with the last weights, not the best weights saved by ModelCheckpoint. I believe the sensible default here is to run the test loop with the best weights saved by ModelCheckpoint.

Pitch


Now that ModelCheckpoint has a pointer to the best weights, Trainer can automatically swap the last weights for the best weights before it runs the test loop.

Alternatives


Possibly, this could be another option on Trainer. I don't like this as much, because running the test loop with the best weights is the behavior most users would expect.

Additional context

enhancement help wanted

All 9 comments

Something like this?
trainer.test(model, load_best_checkpoint=True)

I would make the load_best_checkpoint=True as default...

Yeah this should definitely be the default behavior.

Another question is, can we only do this when ModelCheckpoint is used since Trainer itself doesn’t keep track of the best weights? What if someone writes their own ModelCheckpoint? It seems like there needs to be a common interface that Trainer uses to retrieve the best weights for the test loop. The best weights could then be whatever the checkpoint_callback defines it to be. In this way, I don’t think we’d need to have yet another option on Trainer.
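To make the "common interface" idea concrete, here is a minimal sketch of what it could look like. Everything in it is hypothetical and invented for illustration (`BestCheckpointProvider`, `best_checkpoint_path`, `MinLossCheckpoint`, `best_path_from`); none of these names are part of Lightning's API. The point is only that the callback decides what "best" means, and the Trainer side asks for a path without knowing how it was chosen.

```python
from typing import Optional, Protocol, runtime_checkable


@runtime_checkable
class BestCheckpointProvider(Protocol):
    """Hypothetical interface: any callback that can report its best checkpoint."""

    def best_checkpoint_path(self) -> Optional[str]:
        ...


class MinLossCheckpoint:
    """Toy custom callback that defines 'best' as the lowest validation loss."""

    def __init__(self) -> None:
        self._best_loss = float("inf")
        self._best_path: Optional[str] = None

    def on_validation_end(self, val_loss: float, path: str) -> None:
        # Remember the path only when the loss improves.
        if val_loss < self._best_loss:
            self._best_loss = val_loss
            self._best_path = path

    def best_checkpoint_path(self) -> Optional[str]:
        return self._best_path


def best_path_from(callback: object) -> Optional[str]:
    """What a trainer could call: returns None if the callback tracks no best weights."""
    if isinstance(callback, BestCheckpointProvider):
        return callback.best_checkpoint_path()
    return None
```

With this shape, a user-written checkpoint callback only has to expose one method, and no extra Trainer option is needed to tell the two cases apart.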

why not make it:

# default
test(..., checkpoint='best')

test(..., checkpoint=PATH/CKPT)

with the option for a string 'best'

and make this the default

very good and test(..., checkpoint=None) uses the last...

Nice! I like that idea. To summarize:

  • add an option to test() called checkpoint whose default value is 'best'.
  • if it's None, use the weights from the last epoch.
  • if it's any other string, treat it as a path to a checkpoint.
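The three rules summarized above could be sketched roughly like this. It is a standalone illustration, not Lightning code: `resolve_test_checkpoint` and its `best_path` argument are made-up names, and raising an error when no best checkpoint is known is an assumption about how that case might be handled.

```python
from typing import Optional


def resolve_test_checkpoint(
    checkpoint: Optional[str] = "best",
    best_path: Optional[str] = None,
) -> Optional[str]:
    """Return the checkpoint path to load before testing, or None to keep current weights.

    `best_path` stands in for whatever the checkpoint callback reports as best.
    """
    if checkpoint is None:
        # Current behavior: test with the weights from the last epoch.
        return None
    if checkpoint == "best":
        if not best_path:
            # Assumption: fail loudly rather than silently fall back to the last weights.
            raise ValueError("checkpoint='best' requested, but no best checkpoint is known")
        return best_path
    # Any other string is treated as an explicit path to a checkpoint file.
    return checkpoint
```

For example, `resolve_test_checkpoint("best", best_path="epoch1.ckpt")` picks the callback's best file, while `resolve_test_checkpoint(None)` leaves the last-epoch weights in place.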

ummm. i prefer None to disable it.
there will 100% be cases where people need to disable that haha.

ah yeah since using the last epoch weights is the current behavior, setting it to None (and using the last epoch weights) would effectively disable it. Let me know if my understanding is incorrect.

current behavior is equivalent to None

new default behavior should be “best”
