Pytorch-lightning: Model checkpoint claims test_step() not defined

Created on 1 May 2020 · 5 Comments · Source: PyTorchLightning/pytorch-lightning

🐛 Bug


I'm trying to make my model easily checkpointable for later testing. Creating the checkpoints works without issues, and loading the model as follows seems to at least "work":

model = MyCoolModel.load_from_checkpoint(checkpoint_path, tags_csv=meta_path)

Here checkpoint_path points to the .ckpt file and meta_path to the tags.csv file. In a normal run my model works perfectly fine: the training epochs, the validation steps, and the final test step called at the end all work. The problem begins when I load the model: I am greeted by an error saying I have never defined test_step():

Traceback (most recent call last):
  File "main.py", line 74, in <module>
    run_model(hparams)
  File "main.py", line 64, in run_model
    trainer.test()
  File "/users2/mmatero/anaconda3/envs/project/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 904, in test
    self.run_evaluation(test_mode=True)
  File "/users2/mmatero/anaconda3/envs/project/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 329, in run_evaluation
    raise MisconfigurationException(
pytorch_lightning.utilities.exceptions.MisconfigurationException: You called `.test()` without defining model's `.test_step()`. Please define and try again

Expected behavior

I have clearly defined both test_step and test_epoch_end in my model's definition, and they run completely fine when not loading from a checkpoint. I reuse my validation calls, since the only difference is the dataloader they use; the operations are exactly the same.

def test_step(self, batch, batch_idx):
    return self.validation_step(batch, batch_idx)

def test_epoch_end(self, outputs):
    return self.validation_epoch_end(outputs)

So I'd expect them to still be defined after loading. I had other issues with pytorch-lightning ignoring my test_step definitions on other versions (specifically 0.7.5), but I have downgraded to one that works for a normal train/val/test loop.
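One way to check that the loaded model itself is not what lost test_step: a classmethod loader builds a fresh instance of the class it is called on, so every method defined on that class is still present. The sketch below is a plain-Python stand-in, not Lightning's code; ToyModule, ToyCoolModel, is_overridden and this simplified load_from_checkpoint are all hypothetical. It only assumes that, as in Lightning, load_from_checkpoint is a classmethod that instantiates cls.

```python
class ToyModule:
    """Hypothetical stand-in for LightningModule with a default test_step hook."""

    def test_step(self, batch, batch_idx):
        raise NotImplementedError  # base-class default, i.e. "not overridden"

    @classmethod
    def load_from_checkpoint(cls, checkpoint):
        # Simplified loader: instantiate the *subclass* and restore its weights.
        model = cls()
        model.weights = dict(checkpoint["state_dict"])
        return model


def is_overridden(model, name, base=ToyModule):
    # Overridden means the instance's class provides its own version of the method.
    return getattr(type(model), name) is not getattr(base, name)


class ToyCoolModel(ToyModule):
    def validation_step(self, batch, batch_idx):
        return {"val_loss": sum(batch) / len(batch)}

    def test_step(self, batch, batch_idx):
        # Same delegation as in the issue: reuse the validation logic.
        return self.validation_step(batch, batch_idx)


loaded = ToyCoolModel.load_from_checkpoint({"state_dict": {"w": 1.0}})
print(is_overridden(loaded, "test_step"))  # True
print(loaded.test_step([1.0, 3.0], 0))  # {'val_loss': 2.0}
```

In this toy the restored instance still reports test_step as overridden, which points the search away from the model and toward how the trainer is given (or not given) that model.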

Environment

  • PyTorch Version (e.g., 1.0): 1.4.0
  • Lightning Version: 0.7.3
  • OS (e.g., Linux): Ubuntu 16.04
  • How you installed PyTorch (conda, pip, source): Conda
  • Python version: 3.8.1
  • CUDA/cuDNN version: 10.1
  • GPU models and configuration: Titan XP x3
Labels: bug / fix, help wanted


All 5 comments

I guess that currently the resume_from_checkpoint flag is designed only to resume training state, not to load a trained model.
If you want to evaluate your model from a checkpoint, you can do this: https://pytorch-lightning.readthedocs.io/en/latest/weights_loading.html#checkpoint-loading

Ah, I think you are correct that the intended use is to resume training state. If I "hack it" and call .fit() with epochs=0, I no longer get the error and can still run my tests.

Yeah, true. I just figured out that you still have to call trainer.fit to set up the trainer before you can test your model, even if you use MyModel.load_from_checkpoint. This behavior annoys me, too.

Update: no, you don't. Just call trainer.test(model); it internally calls trainer.fit with the testing=True flag.
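The behavior discussed in this thread can be mimicked with a toy trainer. ToyTrainer, Model and the local MisconfigurationException below are hypothetical and heavily simplified, not Lightning's actual Trainer; the toy only encodes what the comments describe: test(model) attaches the model and routes through fit() with a testing flag, while a bare test() can only use a model that an earlier fit() attached. Passing a checkpoint path for resuming does not, on its own, give the trainer a model.

```python
class MisconfigurationException(Exception):
    """Local stand-in for Lightning's exception of the same name."""


class ToyTrainer:
    def __init__(self, resume_from_checkpoint=None):
        # Only stored; in this toy a resume path never attaches a model by itself.
        self.resume_from_checkpoint = resume_from_checkpoint
        self.model = None
        self.testing = False

    def fit(self, model):
        self.model = model  # fit() is what attaches the model to the trainer
        if self.testing:
            return self.run_evaluation(test_mode=True)
        return "trained"  # placeholder for a real training loop

    def test(self, model=None):
        self.testing = True
        try:
            if model is not None:
                return self.fit(model)  # test(model) routes through fit()
            return self.run_evaluation(test_mode=True)
        finally:
            self.testing = False

    def run_evaluation(self, test_mode):
        # With no model attached, there is no test_step to find.
        if test_mode and not hasattr(self.model, "test_step"):
            raise MisconfigurationException(
                "You called `.test()` without defining model's `.test_step()`."
            )
        return self.model.test_step([1.0, 3.0], 0)


class Model:
    def test_step(self, batch, batch_idx):
        return {"test_loss": sum(batch) / len(batch)}


trainer = ToyTrainer(resume_from_checkpoint="example.ckpt")  # hypothetical path
try:
    trainer.test()  # no model attached yet, mirroring the traceback above
except MisconfigurationException as err:
    print("error:", err)

print(trainer.test(Model()))  # {'test_loss': 2.0}
```

Under these assumptions, the 0-epoch fit() hack mentioned earlier works for the same reason as trainer.test(model): both hand the model to the trainer before evaluation runs.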

Can you upgrade to the latest version? It should have better checks.
