I noticed that even when ModelCheckpoint is used, Trainer by default runs the test loop with the last weights, not the best weights saved by ModelCheckpoint. I believe the sensible default here is to run the test loop with the best weights saved by ModelCheckpoint.
Now that ModelCheckpoint has a pointer to the best weights, Trainer can replace the last weights with the best weights before running the test loop automatically.
Possibly, this could be another option to Trainer. I don't like this as much b/c this is the behavior most users would expect.
Something like this?
trainer.test(model, load_best_checkpoint=True)
I would make load_best_checkpoint=True the default...
Yeah this should definitely be the default behavior.
Another question is, can we only do this when ModelCheckpoint is used since Trainer itself doesn’t keep track of the best weights? What if someone writes their own ModelCheckpoint? It seems like there needs to be a common interface that Trainer uses to retrieve the best weights for the test loop. The best weights could then be whatever the checkpoint_callback defines it to be. In this way, I don’t think we’d need to have yet another option on Trainer.
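To make the "common interface" idea concrete, here is a minimal sketch in plain Python. Everything in it is hypothetical and only for illustration: `BestCheckpointProvider`, `best_checkpoint_path`, `MinLossCheckpoint`, and `path_for_test` are made-up names, not part of the actual Trainer or ModelCheckpoint API. The point is only that Trainer would ask the callback for its notion of "best" instead of tracking it itself.

```python
from typing import Optional, Protocol


class BestCheckpointProvider(Protocol):
    """Hypothetical interface: any checkpoint callback that can report its best checkpoint."""

    def best_checkpoint_path(self) -> Optional[str]:
        ...


class MinLossCheckpoint:
    """Toy custom callback (an assumption) that defines 'best' as the lowest recorded loss."""

    def __init__(self) -> None:
        self._best_loss = float("inf")
        self._best_path: Optional[str] = None

    def record(self, path: str, loss: float) -> None:
        # Keep the path only if this checkpoint beats the best loss seen so far.
        if loss < self._best_loss:
            self._best_loss = loss
            self._best_path = path

    def best_checkpoint_path(self) -> Optional[str]:
        return self._best_path


def path_for_test(callback: Optional[BestCheckpointProvider]) -> Optional[str]:
    """Return the checkpoint to load before testing, or None to keep the last weights."""
    if callback is None:
        return None  # no checkpoint callback in use, nothing to restore
    return callback.best_checkpoint_path()
```

With something like this, a user-written callback only has to implement the one method, and "best" is whatever that callback decides it is.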
why not make it:
# default
test(..., checkpoint='best')
test(..., checkpoint=PATH/CKPT)
with the option for the string 'best', and make this the default
Very good, and test(..., checkpoint=None) uses the last weights...
Nice! I like that idea. To summarize:
test() gets an argument called checkpoint whose default value is best.
Ummm, I prefer None to disable it.
there will 100% be cases where people need to disable that haha.
ah yeah since using the last epoch weights is the current behavior, setting it to None (and using the last epoch weights) would effectively disable it. Let me know if my understanding is incorrect.
current behavior is equivalent to None
new default behavior should be “best”
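A minimal sketch of the semantics agreed on above, assuming a standalone helper rather than the real Trainer code. `resolve_test_checkpoint` and its `best_path` argument are hypothetical names; `best_path` stands in for whatever the checkpoint callback reports as its best checkpoint. Raising an error when no best checkpoint is known is one possible choice here, not something the thread settled on; silently falling back to the last weights would be another.

```python
from typing import Optional


def resolve_test_checkpoint(
    checkpoint: Optional[str] = "best",
    best_path: Optional[str] = None,
) -> Optional[str]:
    """Return the checkpoint path to load before testing, or None to use the last weights."""
    if checkpoint is None:
        # Current behavior: test with the last-epoch weights, nothing is loaded.
        return None
    if checkpoint == "best":
        # Proposed default: use whatever the checkpoint callback considers best.
        if best_path is None:
            raise ValueError("checkpoint='best' was requested but no best checkpoint is known")
        return best_path
    # Anything else is treated as an explicit path to a checkpoint file.
    return checkpoint
```

So `checkpoint=None` reproduces today's behavior, the default `'best'` restores the best weights, and any other string is taken as a path.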