Pytorch-lightning: Only invoke setup() once, not in both trainer.fit() and trainer.test()

Created on 16 Jul 2020 · 5 comments · Source: PyTorchLightning/pytorch-lightning

🚀 Feature


Only invoke def setup(self, step: str) when calling trainer.test(net) if it has not already been called (e.g. by trainer.fit(net)).

Motivation


The setup function is described in the docs as: "use setup to do splits, and build your model internals".
Therefore, I wrote code that performs the train-val-test split and some DataFrame label transformations (e.g. label to one_hot) in this function.
A pretty common pattern is the following:

trainer.fit(net)
trainer.test(net)

Contrary to what I expected, my debug output showed that setup was invoked twice.
This wastes computational resources, and since I did the train-val split randomly, I no longer have access to the indices that were used in either step (with possible further issues, such as the label transformation re-ordering which column represents which label).

Current situation: the train and val steps share one invocation of setup, while the test step triggers a second one.
I assume it is more common for the train-val and test steps to share the same setup code than for setup to do something special only for test (and not for val).
Since Lightning works by providing sensible defaults while still letting you override anything, the logic should be that setup is invoked only once, with a way to specify a special test setup function.

Pitch


Have the trainer keep track of whether setup() has been invoked, so it can be skipped in trainer.test(net) if it already ran in trainer.fit(net). This way the common use case needs less computation and is more in line with what is expected of the magic of Lightning.

I'm not sure what would be the best approach for users to set a separate setup call for testing.
Maybe something like: trainer.test(net, always_invoke_setup=True)?
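The proposed behavior can be illustrated with a toy sketch (plain Python, not the real Trainer or LightningModule; the `_setup_done` flag and `_maybe_setup` helper are hypothetical names, and `always_invoke_setup` is the flag suggested above):

```python
class Net:
    """Stand-in for a LightningModule; records every setup() call."""
    def __init__(self):
        self.setup_calls = []

    def setup(self, step):
        self.setup_calls.append(step)


class Trainer:
    """Toy trainer: setup() runs at most once, unless explicitly forced."""
    def __init__(self):
        self._setup_done = False

    def _maybe_setup(self, model, step, force=False):
        # Skip setup if it already ran, unless the caller forces it.
        if force or not self._setup_done:
            model.setup(step)
            self._setup_done = True

    def fit(self, model):
        self._maybe_setup(model, "fit")

    def test(self, model, always_invoke_setup=False):
        self._maybe_setup(model, "test", force=always_invoke_setup)


net = Net()
trainer = Trainer()
trainer.fit(net)
trainer.test(net)
print(net.setup_calls)  # → ['fit'] — setup ran only once, during fit
```

With trainer.test(net, always_invoke_setup=True), setup would run a second time with step="test", preserving the current behavior for users who need it.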

Alternatives


Include some custom logic that checks whether the data has already been initialized:

def setup(self, step: str):
    # Guard clause: skip the expensive setup if it already ran.
    # getattr avoids an AttributeError when self.data was never set.
    if getattr(self, "data", None) is not None:
        return
    # setup code: splits, label transforms, etc.
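A minimal runnable version of this workaround (plain Python, with `Net` as a hypothetical stand-in for a LightningModule and a dummy dict in place of real splits):

```python
class Net:
    """Stand-in module using the 'check if data is initialized' guard."""
    def __init__(self):
        self.data = None
        self.n_setup_runs = 0  # counts how often the real work executed

    def setup(self, step: str):
        if self.data is not None:
            return  # already initialized; skip the repeated work
        self.n_setup_runs += 1
        # Placeholder for the real split / label-transform logic:
        self.data = {"train": [0, 1], "val": [2], "test": [3]}


net = Net()
net.setup("fit")   # builds the splits
net.setup("test")  # no-op: data already exists
print(net.n_setup_runs)  # → 1
```

The drawback is that every user has to repeat this boilerplate, which is what the pitch above tries to avoid.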

Additional context


Ideas formed by discussing this issue in the questions channel of the pytorch-lightning Slack. Thanks go to the people who replied.

Labels: enhancement, help wanted, let's do it!

All 5 comments

we could also just split the method

fit_setup
test_setup?

@tullie

Or simply just don't call setup('fit') when self.testing == True?

that doesn't happen, no?

if you call .test(), only setup('test') gets called

oh! yeah that's a bug :)

mind submitting a PR?

nice catch!!

