Only invoke setup(self, step: str) when calling trainer.test(net) if it has not already been called before (e.g. by trainer.fit(net)).
The setup function is described in the docs as: "use setup to do splits, and build your model internals".
Therefore, I wrote code that does the train-val-test split and some DataFrame label transformations (e.g. label to one_hot) in this function.
A pretty common pattern is the following:
trainer.fit(net)
trainer.test(net)
Contrary to what I expected, I saw from my debug output that setup was invoked twice.
This is a waste of computational resources, and since I do the train-val split randomly, the second invocation produces a different split, so I lose track of which indices were used in each step (and there are possibly other issues, such as the label transformation re-ordering the columns so that the mapping from number to label changes).
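To make this concrete, here is a minimal sketch of the pattern I mean (the toy model, data, and hyperparameters are invented purely for illustration; the exact hook signature and trainer behaviour depend on the Lightning version, I mirror the step argument from above):

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
import pytorch_lightning as pl

class Net(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def setup(self, step: str):
        # invoked by the trainer; with the current behaviour this runs
        # once for fit and once again for test
        print(f"setup called with step={step}")
        full = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))
        # a fresh random split on every invocation, so the second call
        # no longer matches the indices used during training
        self.train_set, self.val_set = random_split(full, [80, 20])

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self(x), y)

    def test_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=16)

    def test_dataloader(self):
        return DataLoader(self.val_set, batch_size=16)

net = Net()
trainer = pl.Trainer(max_epochs=1)
trainer.fit(net)
trainer.test(net)  # "setup called ..." is printed a second time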
Current situation: the train and val steps share one setup call, while the test step triggers another invocation of setup.
I assume it is more common for the train-val and test steps of the trainer to share the same setup code than for setup to do something special only for test (and not for val).
As Lightning works by giving sensible defaults while allowing you to hack at anything you want, the logic should be that setup is only invoked once, with a way to specify a special test setup when needed.
Have the trainer keep track of whether setup() has been invoked, so setup() can be skipped in trainer.test(net) if it was already invoked in trainer.fit(net). This way the common use case needs less computation and is more in line with what is expected of the magic of Lightning.
I'm not sure what would be the best approach for users to set a separate setup call for testing.
Maybe something like: trainer.test(net, always_invoke_setup=True)?
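A rough sketch of the bookkeeping I have in mind (a simplified stand-in, not Lightning's actual Trainer code; the names _setup_called and always_invoke_setup are only illustrative):

class Trainer:
    # simplified stand-in for pytorch_lightning.Trainer, only to show
    # the proposed bookkeeping; not the real implementation
    def __init__(self):
        self._setup_called = False

    def _maybe_setup(self, model, step, force=False):
        if force or not self._setup_called:
            model.setup(step)
            self._setup_called = True

    def fit(self, model):
        self._maybe_setup(model, "fit")
        # ... run the train/val loops ...

    def test(self, model, always_invoke_setup=False):
        # skipped if fit() already ran setup, unless explicitly forced
        self._maybe_setup(model, "test", force=always_invoke_setup)
        # ... run the test loop ...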
Include some custom logic that checks whether the data has already been initialized:

def setup(self, step: str):
    if self.data is None:
        # expensive setup code: splits, label transforms, etc.
        ...
    else:
        # data already prepared; do nothing
        pass
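Note that this workaround assumes self.data is set to None in __init__; otherwise the first call raises an AttributeError. Checking getattr(self, "data", None) is None instead avoids having to touch __init__.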
These ideas were formed by discussing this issue on the pytorch-lightning Slack, in the questions channel. Thanks go to the people who replied.
we could also just split the method into fit_setup / test_setup?
@tullie
Or simply just don't call setup('fit') when self.testing == True??
that doesn’t happen no?
if you call .test() only the setup('test') gets called
Since it calls .fit() within test, it calls .setup('fit') too, I think:
https://github.com/PyTorchLightning/pytorch-lightning/blob/7b4db3045dcc9e6bb0b66e409b25bb2c7fa378f0/pytorch_lightning/trainer/trainer.py#L1033-L1048
oh! yeah that's a bug :)
mind submitting a PR?
nice catch!!