Pytorch-lightning: Add dataloader arg to Trainer.test()

Created on 6 Apr 2020 · 8 Comments · Source: PyTorchLightning/pytorch-lightning

🚀 Feature


It would be nice if you could use a model for inference using:
Trainer.test(model, test_dataloaders=test_loader)

Motivation

This would match the calling structure of Trainer.fit() and allow test to be called on any dataset, multiple times.

Pitch

Here's a use case. After training a model using 5-fold cross-validation, you may want to stack the 5 checkpoints across multiple models, which will require a) out-of-fold (OOF) predictions and b) the 5 test predictions (which will be averaged). It would be cool if a & b could be generated as follows:

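for f in folds:
    # load_from_checkpoint is a classmethod that returns a new model; keep the result
    model1 = model1.load_from_checkpoint(f'path/to/model1_fold{f}.ckpt')
    trainer.test(model1, test_dataloaders=valid_loader)  # (a) OOF: valid_loader must be fold f's held-out split
    trainer.test(model1, test_dataloaders=test_loader)   # (b) test-set predictions for fold f

    model2 = model2.load_from_checkpoint(f'path/to/model2_fold{f}.ckpt')
    trainer.test(model2, test_dataloaders=valid_loader)
    trainer.test(model2, test_dataloaders=test_loader)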
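Once the per-fold predictions exist, combining them is plain bookkeeping: each fold's out-of-fold predictions are written back to the rows that fold held out, and the per-fold test-set predictions are averaged row by row. A minimal pure-Python sketch, assuming each fold's predictions come back as a list of floats; `assemble_oof`, `average_test` and their arguments are hypothetical names, not Lightning APIs:

```python
from typing import List, Sequence


def assemble_oof(n_rows: int,
                 fold_indices: Sequence[Sequence[int]],
                 oof_preds: Sequence[Sequence[float]]) -> List[float]:
    """Place each fold's validation predictions back at the rows that fold held out."""
    oof = [float("nan")] * n_rows
    for idx, preds in zip(fold_indices, oof_preds):
        if len(idx) != len(preds):
            raise ValueError("each fold needs one prediction per held-out row")
        for i, p in zip(idx, preds):
            oof[i] = p
    return oof


def average_test(test_preds: Sequence[Sequence[float]]) -> List[float]:
    """Average the per-fold predictions on the shared test set, row by row."""
    n_folds = len(test_preds)
    return [sum(row) / n_folds for row in zip(*test_preds)]


# Toy data: two folds over 4 training rows, and a 3-row test set.
oof = assemble_oof(4, [[0, 2], [1, 3]], [[0.1, 0.3], [0.2, 0.4]])
test = average_test([[0.5, 0.7, 0.9], [0.7, 0.9, 1.1]])
print(oof)  # [0.1, 0.2, 0.3, 0.4]
print(test)
```

In a stacking setup, the OOF column produced for each base model (model1, model2) would then serve as one input feature of the second-level model, and the averaged test predictions as the matching features at inference time.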

Alternatives

Maybe I'm misunderstanding how test works and there is an easier way? Or perhaps the best way to do this is to write an inference function as you would in pure PyTorch?
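For comparison, a hand-written inference function is short. The sketch below is framework-agnostic so it runs without PyTorch installed; it assumes the loader yields model inputs directly (if it yields `(inputs, targets)` pairs, unpack them first). `predict` and its `no_grad` parameter are hypothetical; with PyTorch you would pass `no_grad=torch.no_grad`, and `model.eval()` is the standard `torch.nn.Module` method:

```python
from contextlib import nullcontext
from typing import Any, Callable, Iterable, List


def predict(model: Callable[[Any], Iterable[Any]],
            loader: Iterable[Any],
            no_grad: Callable[[], Any] = nullcontext) -> List[Any]:
    """Run `model` on every batch from `loader` and collect the per-sample outputs."""
    # A torch.nn.Module has .eval(), which puts dropout/batch-norm in inference mode.
    if hasattr(model, "eval"):
        model.eval()
    outputs: List[Any] = []
    # With PyTorch, pass no_grad=torch.no_grad so no autograd graph is built.
    with no_grad():
        for batch in loader:
            outputs.extend(model(batch))
    return outputs


# Toy stand-ins: a "model" that doubles its inputs and a "loader" of two batches.
print(predict(lambda batch: [2 * x for x in batch], [[1, 2], [3]]))  # [2, 4, 6]
```

What this loses compared with `trainer.test` is everything the Trainer configures for you (devices, distributed execution, precision), which is the trade-off the question is about.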

Additional context

Labels: Priority P0, discussion, enhancement, help wanted, let's do it!


All 8 comments

Hi! Thanks for your contribution, great first issue!

I am in favour of adding this option, but first, let's see how it fits the API.
@williamFalcon any strong objections? cc: @PyTorchLightning/core-contributors

test is meant to ONLY operate on the test set. It's meant to keep people from using the test set when they shouldn't, haha (i.e., only right before publication or right before production use).

Additions that I'm not sure align well:

  1. Trainer.test as a static method, called on the class rather than on a Trainer instance. Why wouldn't you just init the trainer? Otherwise you won't be able to test in distributed environments or configure the things you need, like apex, etc.

Additions that are good:

  1. Allowing the test function to take in a dataloader. This also aligns with how fit works.
  2. fit should also not take a test dataloader (not sure if it does now).
  3. The current .test already uses the test dataloader defined in your LightningModule, so the ONLY addition we're talking about here is allowing test to ALSO take in a dataloader and use only that one.
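The behaviour described in point 3 can be sketched with stand-in classes: if a dataloader is passed to `test`, only that one is used; otherwise the trainer falls back to the module's own `test_dataloader()`. `ToyModule`, `ToyTrainer` and the simplified `test_step(batch)` signature are illustrative assumptions, not Lightning's actual implementation; the `test_dataloaders` keyword follows the name proposed in this issue:

```python
from typing import Iterable, List, Optional


class ToyModule:
    """Hypothetical stand-in for a LightningModule that defines its own test data."""

    def test_dataloader(self) -> Iterable[List[int]]:
        return [[1, 2], [3, 4]]

    def test_step(self, batch: List[int]) -> int:
        return sum(batch)


class ToyTrainer:
    """Hypothetical stand-in for an instantiated Trainer, showing only the loader fallback."""

    def test(self, model: ToyModule,
             test_dataloaders: Optional[Iterable[List[int]]] = None) -> List[int]:
        # Use the explicitly passed loader if there is one, otherwise the module's own.
        loader = test_dataloaders if test_dataloaders is not None else model.test_dataloader()
        return [model.test_step(batch) for batch in loader]


trainer = ToyTrainer()
model = ToyModule()
print(trainer.test(model))                               # [3, 7]: module's test_dataloader
print(trainer.test(model, test_dataloaders=[[10, 20]]))  # [30]: only the passed loader
```

Checking `is not None` rather than truthiness means an explicitly passed (even empty) loader is always respected instead of being silently replaced by the module's default.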

btw I'm interested in how to "train a model using 5-fold cross-validation" in PL.
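One common approach, sketched here as an assumption rather than a built-in Lightning feature: split the row indices into k folds, and for each fold train a fresh model with its own Trainer on the other k-1 folds, validating on the held-out fold. The index bookkeeping needs nothing from Lightning; `kfold_indices` is a hypothetical helper (scikit-learn's `KFold` does the same job):

```python
import random
from typing import Iterator, List, Optional, Tuple


def kfold_indices(n_rows: int, k: int,
                  seed: Optional[int] = None) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_idx, valid_idx) pairs; every row lands in exactly one valid_idx."""
    if not 2 <= k <= n_rows:
        raise ValueError("need 2 <= k <= n_rows")
    idx = list(range(n_rows))
    if seed is not None:
        random.Random(seed).shuffle(idx)
    # Spread the remainder so fold sizes differ by at most one row.
    fold_sizes = [n_rows // k + (1 if f < n_rows % k else 0) for f in range(k)]
    start = 0
    for size in fold_sizes:
        valid = idx[start:start + size]
        train = idx[:start] + idx[start + size:]
        yield train, valid
        start += size


for fold, (train_idx, valid_idx) in enumerate(kfold_indices(10, 5)):
    # Per fold (not shown): build train/valid loaders from these indices, then
    # create a fresh model and Trainer and call trainer.fit(...) on them.
    print(fold, valid_idx)
```

The predictions each fold's model makes on its `valid_idx` rows are exactly the out-of-fold (OOF) predictions described in the Pitch above.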

Let's do this:

  1. Add a test dataloader argument to .test()
  2. Remove the test dataloader argument from .fit()?


@Ir1d Try this:
https://www.kaggle.com/rohitgr/quest-bert


Hey @rohitgr7! The link seems to be broken, could you point to any other resource? Thanks!

