It would be nice if you could use a model for inference via:
Trainer.test(model, test_dataloaders=test_loader)
This would match the calling structure of Trainer.fit() and allow test to be called on any dataset, multiple times.
Here's a use case: after training a model with 5-fold cross-validation, you may want to stack the 5 checkpoints across multiple models, which requires a) out-of-fold (OOF) predictions and b) the 5 test-set predictions (which will be averaged). It would be cool if a & b could be generated as follows:
for f in folds:
    # load_from_checkpoint is a classmethod that returns a new model
    # instance, so its result must be assigned
    model1 = model1.load_from_checkpoint(f'path/to/model1_fold{f}.ckpt')
    trainer.test(model1, test_dataloaders=valid_loader)  # a) OOF predictions
    trainer.test(model1, test_dataloaders=test_loader)   # b) test predictions
    model2 = model2.load_from_checkpoint(f'path/to/model2_fold{f}.ckpt')
    trainer.test(model2, test_dataloaders=valid_loader)
    trainer.test(model2, test_dataloaders=test_loader)
Maybe I'm misunderstanding how test works and there is an easier way? Or perhaps the best way to do this is to write an inference function as you would in pure PyTorch?
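For reference, the "inference function in pure PyTorch" route mentioned above can be quite short. A minimal sketch (the `predict` helper and the tiny demo model are illustrative, not part of any library API):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def predict(model, loader, device="cpu"):
    # Plain-PyTorch inference loop: no Trainer involved.
    model.to(device)
    model.eval()
    outputs = []
    with torch.no_grad():
        for (x,) in loader:
            outputs.append(model(x.to(device)).cpu())
    return torch.cat(outputs)

# Tiny demo: a linear model over random features.
model = torch.nn.Linear(4, 2)
loader = DataLoader(TensorDataset(torch.randn(10, 4)), batch_size=4)
preds = predict(model, loader)
print(preds.shape)  # torch.Size([10, 2])
```

Calling this once with valid_loader and once with test_loader per fold would produce the OOF and test predictions directly, without going through trainer.test at all.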
Hi! Thanks for your contribution, great first issue!
I am in favour of adding this option, but first let's see how it fits the API.
@williamFalcon any strong suggestion against? cc: @PyTorchLightning/core-contributors
test is meant to ONLY operate on the test set. It's meant to keep people from using the test set when they shouldn't haha (ie: only right before publication or right before production use).
additions that i'm not sure align well
additions that are good
btw I'm interested in how to "train a model using 5-fold cross-validation" in PL.
Let's do this:
btw I'm interested in how to "train a model using 5-fold cross-validation" in PL.
@Ir1d Try this:
https://www.kaggle.com/rohitgr/quest-bert
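In case the link goes stale: the usual pattern is to split dataset indices into folds outside of Lightning and train one model per fold. A minimal sketch of the fold-splitting part, assuming sklearn is available; the Lightning-specific calls are shown only as comments, and any dataset/model names there would be your own:

```python
import numpy as np
from sklearn.model_selection import KFold

# 5-fold split over dataset indices; each fold yields train/valid index arrays.
n_samples = 20
kf = KFold(n_splits=5, shuffle=True, random_state=42)
folds = list(kf.split(np.arange(n_samples)))

for fold_id, (train_idx, valid_idx) in enumerate(folds):
    # With Lightning you would build per-fold loaders and a fresh Trainer here:
    # train_loader = DataLoader(Subset(dataset, train_idx), ...)
    # valid_loader = DataLoader(Subset(dataset, valid_idx), ...)
    # trainer = pl.Trainer(...); trainer.fit(model, train_loader, valid_loader)
    pass

print(len(folds), len(folds[0][0]), len(folds[0][1]))  # 5 16 4
```

Each sample appears in exactly one validation fold, which is what makes the per-fold valid_loader predictions usable as OOF predictions for stacking.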
Hey @rohitgr7! The link seems to be broken, could you point to any other resource? Thanks!
@ArthDh Try this one: https://www.kaggle.com/rohitgr/roberta-with-pytorch-lightning-train-test-lb-0-710