Pytorch-lightning: Predict method for test set

Created on 10 Aug 2019 · 34 comments · Source: PyTorchLightning/pytorch-lightning

The main Lightning module requires you to define the test_dataloader function, but I can't find any method that takes the test loader as input. Is there a model.predict() method to call on the test set?

Labels: enhancement, help wanted


Thanks for bringing that up! We use separate scripts for the test set, but it makes sense to add a .predict function to the Trainer.

The proposed flow would be:

CASE 1: train, then test (proposed)

trainer = Trainer(...)
trainer.fit(model)

# the optional prediction call, which will use the test set
# this lets the researcher be 100% sure the test set was not used in training
# linking predict to the Trainer means multi-GPU and cluster support for the test set comes for free
trainer.test()


# in the LightningModule, test_step would be called

CASE 2: test from saved model (proposed)

model = CoolModel.load_from_metrics(weights_path='path/to/weights/', tags_csv='path/to/hparams.csv')
trainer = Trainer()
trainer.test(model)

Case 3: model inference (current)

If you want to actually make predictions (maybe for a server, system, etc.), then just call forward in your LightningModule.

model = CoolModel.load_from_metrics(weights_path='path/to/weights/', tags_csv='path/to/hparams.csv')     
model.freeze()

y_hat = model(x)

How do flows 1 and 2 sound? Are these flows intuitive to you?

Yes, they are. And I like the idea that you can define the test_step freely.

Regarding the 3rd case: it's fine, but the way you make predictions might differ from the training phase. For instance, you might have two images of the same object (so the same label), and in that case you want to keep the more confident prediction of the two.
At the moment, I guess the only way to do this would be to load the weights of the trained model and write a traditional loop.
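Here is a minimal sketch of that kind of prediction-time aggregation. It is written in plain Python so it runs without Lightning. The `(object_id, label, confidence)` tuple format and the function name are hypothetical and chosen only for illustration:

```python
from typing import Dict, Iterable, Tuple

# Hypothetical prediction record: (object_id, predicted_label, confidence)
Prediction = Tuple[str, int, float]


def most_confident_per_object(preds: Iterable[Prediction]) -> Dict[str, Tuple[int, float]]:
    """Keep only the highest-confidence prediction for each object id."""
    best: Dict[str, Tuple[int, float]] = {}
    for object_id, label, confidence in preds:
        # Replace the stored prediction only if this one is more confident
        if object_id not in best or confidence > best[object_id][1]:
            best[object_id] = (label, confidence)
    return best


# Two images of the same object "a", one image of object "b"
preds = [("a", 3, 0.62), ("a", 5, 0.91), ("b", 1, 0.77)]
print(most_confident_per_object(preds))  # {'a': (5, 0.91), 'b': (1, 0.77)}
```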

@lorenzoFabbri so define another method (optional) .predict in cases where your .forward and .predict behavior is different?

So if no .predict is defined, the model uses .forward; otherwise, it uses .predict?

Yes, that's a possibility. But I was just wondering whether it was already available. Sometimes you just forget to call model.eval() and nothing works as expected: the less code I have to write, the better :)
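The proposed fallback can be sketched roughly as follows. It uses `.predict` if the module defines one and `.forward` otherwise. It also switches the model to eval mode so that step can't be forgotten. `Model`, `AveragingModel`, and `run_prediction` are hypothetical stand-ins, not Lightning classes. Here `eval()` only mimics `torch.nn.Module.eval()` by flipping a `training` flag:

```python
class Model:
    """Hypothetical stand-in for a LightningModule that only defines forward."""

    def __init__(self):
        self.training = True

    def eval(self):
        # Mimics torch.nn.Module.eval(): switch off training-mode behavior
        self.training = False
        return self

    def forward(self, x):
        return x * 2


class AveragingModel(Model):
    """Defines predict because prediction differs from forward (here: average over inputs)."""

    def predict(self, xs):
        outputs = [self.forward(x) for x in xs]
        return sum(outputs) / len(outputs)


def run_prediction(model, x):
    model.eval()  # done for the user, so it can't be forgotten
    predict = getattr(model, "predict", None)
    if callable(predict):
        return predict(x)
    return model.forward(x)  # no predict defined: fall back to forward


print(run_prediction(Model(), 3))                # 6
print(run_prediction(AveragingModel(), [1, 3]))  # 4.0
```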

Not yet, but I'll add it for the next release this week. I just want to make sure the API and usage are intuitive.

I agree, a separate test script is a PITA right now.

@lorenzoFabbri, if you need it sooner, maybe submit a PR? That way you code it once and it's available to everyone. I'll review it as a sanity check for bugs.

Quickly thinking through this, I think you can just pass in a flag to validate(test=True), and change where the forward calls get routed to (training_step, etc...).

Otherwise, I'll add it later this coming week.

I've been looking into adding a predict/test function. To have it work like case 2, I think we would need to extract all the setup done in fit into another function, or generalize it somehow.

Is anyone working on Case 1? I would like to use the functionality soon, so I can give it a try.
@williamFalcon Could you explain what you mean by

and change where the forward calls get routed to (training_step, etc...)

I don't get it :) Or does it simply mean to take care that validation_end() is not called when test=True?

I'm not working on it at the moment.

I’m finishing up some ICLR work, so I can add this in the coming days.

Do cases 1 and 2 make intuitive sense to everyone?

But @expectopatronum, if you want to give it a shot, I think we need to:
Case 1:

  1. rename run_validation to run_evaluation.
  2. pass a flag, “test” or “val”, to the run_evaluation function.
  3. add a new method called test that calls run_evaluation with the “test” flag.
  4. if the test flag is present, use the test dataloader and call test_step if it is defined. If test_step is not defined, use validation_step.
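Here is roughly how those four steps could look in a heavily simplified, hypothetical trainer. It is not the actual Lightning code. It uses a boolean `test` argument instead of a string, and plain lists stand in for dataloaders:

```python
class ToyTrainer:
    """Hypothetical, simplified trainer illustrating the four steps above."""

    def __init__(self, model):
        self.model = model

    def run_evaluation(self, test=False):
        # Step 2: one flag decides which resources are used
        if test:
            dataloader = self.model.test_dataloader()
            # Step 4: prefer test_step, fall back to validation_step
            step = getattr(self.model, "test_step", None) or self.model.validation_step
        else:
            dataloader = self.model.val_dataloader()
            step = self.model.validation_step
        return [step(batch, batch_nb) for batch_nb, batch in enumerate(dataloader)]

    def validate(self):
        return self.run_evaluation(test=False)

    def test(self):
        # Step 3: test() reuses the evaluation loop with the test flag set
        return self.run_evaluation(test=True)


class ValOnlyModel:
    def val_dataloader(self):
        return [1, 2]

    def test_dataloader(self):
        return [10, 20]

    def validation_step(self, batch, batch_nb):
        return ("val", batch)


class FullModel(ValOnlyModel):
    def test_step(self, batch, batch_nb):
        return ("test", batch)


print(ToyTrainer(FullModel()).test())     # [('test', 10), ('test', 20)]
print(ToyTrainer(ValOnlyModel()).test())  # [('val', 10), ('val', 20)]
```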

Case 2:
Same as above, but since fit is never called, call it from the test function with a “test” flag. Then, in pretrain_routine, don’t run the sanity check or fit; just call the eval function and exit.
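A similarly simplified sketch of Case 2 follows. Here `test(model)` runs the shared setup in `pretrain_routine` but skips the sanity check and training. The `in_test_mode` attribute, the `log` list (which only records which phases ran), and the class itself are assumptions for illustration:

```python
class ToyTrainer:
    """Hypothetical trainer illustrating Case 2: test() on a model that was never fit."""

    def __init__(self):
        self.model = None
        self.in_test_mode = False
        self.log = []  # records which phases ran, for illustration only

    def fit(self, model):
        self.model = model
        self.pretrain_routine()

    def test(self, model=None):
        self.in_test_mode = True
        if model is not None:
            # Case 2: fit() was never called, so reuse its setup path with the test flag set
            self.fit(model)
        else:
            # Case 1: a model is already attached from an earlier fit()
            self.run_evaluation()

    def pretrain_routine(self):
        self.log.append("setup")  # stands in for device placement, loggers, dataloaders
        if self.in_test_mode:
            # Skip the sanity check and training: just evaluate and exit
            self.run_evaluation()
            return
        self.log.append("sanity_check")
        self.log.append("train")  # stands in for the training loop

    def run_evaluation(self):
        self.log.append("test" if self.in_test_mode else "val")


trainer = ToyTrainer()
trainer.test(model=object())
print(trainer.log)  # ['setup', 'test']
```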

Last, we need a test for each case

and docs.

Thanks, I am working on it!

Clarification on the flag: I think a flag

in_test_mode = True

in the Trainer is probably better than a string being passed around. But do what makes sense to you and I’ll comment on the PR.

Sure, makes sense!
I guess it would also make sense to rename validate to evaluate and pass on the flag? (at least from my current understanding of the code)

yup.

A good way to think about it is to reuse all the validation logic, but add a test flag so the testing resources are used instead.

validation_step -> test_step   
val_dataloaders -> test_dataloaders   
etc...

So for testing we can also have multiple test_dataloaders? In the docs it's only mentioned for val_dataloader.

Yup, we should support that.
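Supporting multiple test dataloaders mostly means looping over them and keeping the results separate. Here is a minimal sketch. It assumes `test_step` receives a `dataloader_idx` argument to tell the loaders apart; that argument name is an assumption, not a confirmed part of the API at the time. `run_test` and `toy_test_step` are hypothetical:

```python
from statistics import mean


def run_test(test_step, test_dataloaders):
    """Hypothetical sketch: evaluate every test dataloader and report results per loader."""
    results = []
    for dataloader_idx, dataloader in enumerate(test_dataloaders):
        # Pass dataloader_idx so test_step can tell the loaders apart
        losses = [test_step(batch, batch_idx, dataloader_idx)
                  for batch_idx, batch in enumerate(dataloader)]
        results.append({"dataloader_idx": dataloader_idx, "test_loss": mean(losses)})
    return results


def toy_test_step(batch, batch_idx, dataloader_idx):
    # Pretend the "loss" is just the batch value
    return float(batch)


print(run_test(toy_test_step, [[1, 3], [10, 20, 30]]))
# [{'dataloader_idx': 0, 'test_loss': 2.0}, {'dataloader_idx': 1, 'test_loss': 20.0}]
```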

Great, I agree :)

Should I create a WIP pull request or should I wait until I think I am done?

For case 2 the current output would look like this:
1it [00:00, 78.39it/s, batch_nb=0, epoch=0, gpu=0, loss=0.000, test_loss=0.0152, v_nb=0]
(I am not sure where '1it' comes from yet)
I guess it would make sense to not show batch_nb, epoch, loss?
Also v_nb (what is that by the way?)

I think the validation call takes care not to log irrelevant stuff. 1it means 1 iteration; that number changes with the number of batches.

v_nb is the experiment version number.
There’s a blacklist in the code; let’s just show what doesn’t get blacklisted.
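The filtering can be sketched as a dictionary comprehension over the metrics from the progress bar output quoted above. The contents of `TEST_BLACKLIST` and the helper name are hypothetical; the real blacklist in the code may contain different keys:

```python
# Metrics as they appear in the progress bar output quoted above
metrics = {"batch_nb": 0, "epoch": 0, "gpu": 0, "loss": 0.000, "test_loss": 0.0152, "v_nb": 0}

# Hypothetical blacklist for test runs; the actual list in the code may differ
TEST_BLACKLIST = {"batch_nb", "epoch", "loss"}


def visible_metrics(metrics, blacklist):
    """Return only the metrics whose keys are not blacklisted."""
    return {k: v for k, v in metrics.items() if k not in blacklist}


print(visible_metrics(metrics, TEST_BLACKLIST))  # {'gpu': 0, 'test_loss': 0.0152, 'v_nb': 0}
```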

A WIP PR is fine if you want. I’ll review it once the PR is completely done with tests (and the tests pass); I’ll do a code review at that point.

I have added 3 test functions:

  • test_multiple_test_dataloader (analogous to test_multiple_val_dataloader)
  • test_multiple_test_dataloader
  • test_running_test_after_fitting

The last two are only there to check whether the code runs. I couldn't find any code that tests the functionality of the validation functions, so I am not sure how to properly test the testing code. Looking forward to suggestions.

One thing that could easily be checked, by creating a new model, is whether it still runs when the model has a validation_step but no test_step.

Regarding the documentation: the structure of the Trainer documentation is not clear to me. Where should I put the test() documentation? Should I create a new section below Validation loop in index.md?

@expectopatronum awesome PR! very key feature we were missing. Thanks for your work on this!

Merged to master

Is there any documentation on the predict functionalities implemented in the PR?

I am not sure what you mean. This PR implements test, which is documented here

So this test functionality can only return metrics, and is not the same as inference/predict?

Any updates on proper inference that returns predictions?
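For reference, inference that returns predictions boils down to collecting the outputs of `forward` over every batch, rather than reducing them to metrics as `test` does. Here is a plain-Python sketch. `predict_all` and `toy_forward` are hypothetical. With a real model you would also call `model.eval()` (or `model.freeze()`) and run the loop under `torch.no_grad()`:

```python
def predict_all(forward, batches):
    """Run forward on every batch and keep each per-sample output (no metric reduction)."""
    predictions = []
    for batch in batches:
        # forward returns one output per sample in the batch; keep them all
        predictions.extend(forward(batch))
    return predictions


def toy_forward(batch):
    # Hypothetical stand-in for model(x): doubles each input
    return [2 * x for x in batch]


print(predict_all(toy_forward, [[1, 2], [3]]))  # [2, 4, 6]
```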

@Diyago, mind opening a new issue if needed?
