Right now, a lot of model setup happens in the `fit` and `run_pretrain` routines. However, that setup also needs to happen before we run `evaluate`, `test`, etc., so we end up calling `fit` even when we only want to test. This makes the code hard to follow.
We should refactor model setup out into its own method that `fit`, `train`, `test`, etc. all call, to make things a little clearer. I need to do some more work to figure out what the current state is and what new state we should aim for, but I wanted to create an issue to track this.
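A minimal sketch of the shared-setup idea, using plain-Python stand-ins rather than the real Lightning classes. `ToyModule`, `ToyTrainer` and `_setup_model` are hypothetical names chosen for illustration; the point is only that `fit` and `test` both go through one setup method, so testing no longer requires calling `fit` first.

```python
class ToyModule:
    """Stand-in for a LightningModule; records which setup hooks ran."""

    def __init__(self):
        self.calls = []

    def configure_optimizers(self):
        self.calls.append("configure_optimizers")
        return "optimizer"


class ToyTrainer:
    """Hypothetical trainer in which fit() and test() share one setup method."""

    def __init__(self):
        self._model = None
        self.optimizer = None

    def _setup_model(self, model):
        # Skip setup if this exact model has already been set up.
        if self._model is model:
            return
        self._model = model
        # Setup that used to live inside fit() would go here (e.g. optimizers, dataloaders).
        self.optimizer = model.configure_optimizers()

    def fit(self, model):
        self._setup_model(model)
        return "fit done"

    def test(self, model):
        # test() no longer depends on fit() having run first.
        self._setup_model(model)
        return "test done"


model = ToyModule()
trainer = ToyTrainer()
trainer.test(model)  # setup runs here
trainer.fit(model)   # setup is skipped: already done for this model
```

Keying the "already set up" check on model identity is one simple choice; a real implementation might instead track which pieces of setup (data, optimizers, distributed state) have already been done.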
Even though validation will not be allowed outside training (as mentioned in #770), it would still be nice not to have to write the following:
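```python
import pytorch_lightning as pl


class MyModel(pl.LightningModule):
    # validation_step, validation_end and val_dataloader are defined as usual

    def test_dataloader(self):
        # use the val data loader
        return self.val_dataloader()
        # or create a separate test data loader
        # ...

    def test_step(self, *args, **kwargs):
        return self.validation_step(*args, **kwargs)

    def test_end(self, *args, **kwargs):
        return self.validation_end(*args, **kwargs)
```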
Echoing some of the Slack discussion on a proposed API that could be part of this refactor: it's becoming clear that continuously adding `Trainer` flags to accommodate every new feature isn't scalable. `Trainer` is becoming a monolithic catch-all for every function that takes a LightningModule.
I propose we begin to split Trainer into smaller components such as Tester, Tuner, and Trainer.
This would turn the following code:
```python
trainer = Trainer()
new_lr = trainer.lr_finder(model, ...)
model.hparams.lr = new_lr
new_batch_size = trainer.scale_batch_size(model, ...)
model.hparams.batch_size = new_batch_size
trainer.fit(model)
trainer.test(model)
```
into:
```python
tuner.lr_finder(model)
tuner.scale_batch_size(model)
Trainer(...).fit(model)
Tester(...).test(model)
```
I hope this API refactor will also set up a framework for reusable Lightning functions (i.e. functions/classes that take LightningModules). For example, a user might write and share a new `AwsDeployer` class that takes a LightningModule and deploys it to a server.
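A reusable "Lightning function" could be little more than an agreed-upon interface: a callable that takes a module. The sketch below assumes an abstract base class named `LightningFunction` (a hypothetical name) and gives `AwsDeployer` a placeholder body that only records what it would deploy; no AWS calls are made, and `"my-endpoint"` is an illustrative value.

```python
from abc import ABC, abstractmethod


class LightningFunction(ABC):
    """Hypothetical base class: anything that consumes a LightningModule."""

    @abstractmethod
    def __call__(self, module):
        ...


class AwsDeployer(LightningFunction):
    """Illustrative only: records deployments instead of talking to AWS."""

    def __init__(self, endpoint):
        self.endpoint = endpoint
        self.deployed = []

    def __call__(self, module):
        # A real deployer would export the module and upload it; we just record it.
        self.deployed.append((self.endpoint, type(module).__name__))
        return self.endpoint


def apply_all(module, functions):
    """Run a sequence of Lightning functions over the same module."""
    return [fn(module) for fn in functions]
```

Using an ABC means a third-party class that forgets to implement `__call__` fails at construction time rather than when it is first used.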
Regarding implementation, there is the question of how these components can share state. I can see three options:
1. We compose Lightning function classes out of other Lightning functions, e.g.:
```python
trainer = Trainer(...)
tuner = Tuner(trainer)
tuner.lr_finder(model)
trainer.fit(model)
```
This would be the shortest path to separating the components. All we'd need to do is make public the `Trainer` functions that `Tuner` uses; the main one, `self.fit(...)`, already is.
2. We add a `LightningConfig` dataclass that is passed to each Lightning function, e.g.:
```python
lightning_config = LightningConfig(
    logger=TensorBoardLogger,
    distributed_backend='ddp',
    gpus=4
)
Tuner(lightning_config).lr_finder(model)
Trainer(lightning_config).fit(model)
```
The LightningModule could cache results that are shared by both (e.g. preparing data, setting up dataloaders and optimizers, the current distributed process state, etc.). Note that different Lightning functions could place different requirements on the config.
3. We add the training config to the LightningModule itself, e.g.:
```python
model = LightningModule(hparams, lightning_config)
tuner.lr_finder(model)
trainer.fit(model)
```
Again, shared results can be cached by the LightningModule.
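Option 1 (composition) can be sketched with toy classes: the `Tuner` holds a `Trainer` and reuses one of its public methods for short trial runs, so any state the trainer keeps is shared automatically. Everything here is hypothetical, including `run_steps`; the model is a one-weight quadratic trained by plain gradient descent so the example stays self-contained.

```python
class ToyModule:
    """Stand-in model: one weight, loss (w - 1)^2, plain gradient descent."""

    def __init__(self):
        self.reset(1.0)

    def reset(self, lr):
        self.lr, self.w = lr, 0.0

    @property
    def loss(self):
        return (self.w - 1.0) ** 2

    def step(self):
        # d/dw (w - 1)^2 = 2 (w - 1)
        self.w -= self.lr * 2 * (self.w - 1.0)


class Trainer:
    def __init__(self, max_steps=10):
        self.max_steps = max_steps
        self.steps_run = 0  # state the Tuner shares simply by holding the Trainer

    def run_steps(self, module, n):
        # Made public so other Lightning functions (e.g. Tuner) can reuse it.
        for _ in range(n):
            module.step()
        self.steps_run += n

    def fit(self, module):
        self.run_steps(module, self.max_steps)


class Tuner:
    """Option 1: built around an existing Trainer instead of duplicating its loop."""

    def __init__(self, trainer):
        self.trainer = trainer

    def lr_finder(self, module, candidates=(1e-1, 1e-2, 1e-3), trial_steps=2):
        losses = {}
        for lr in candidates:
            module.reset(lr)
            self.trainer.run_steps(module, trial_steps)
            losses[lr] = module.loss
        module.reset(min(losses, key=losses.get))  # keep the best lr, fresh weights
        return module.lr


trainer = Trainer(max_steps=10)
tuner = Tuner(trainer)
model = ToyModule()
best_lr = tuner.lr_finder(model)
trainer.fit(model)
```

The trade-off is coupling: `Tuner` now depends on whichever `Trainer` methods are made public, which is exactly the surface that would need to stay stable.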
I personally prefer option 2, although there is a concern that passing these `lightning_config`s around adds too much boilerplate.
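Option 2 can be sketched as a frozen dataclass handed to every component, with expensive shared setup cached on the module so it runs only once no matter which component touches it first. The field names follow the example above, but `logger` is a plain string stand-in here rather than a logger class, and `ToyModule.prepared_data`, `loss_fn` and the grid-search body are hypothetical.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class LightningConfig:
    """Hypothetical shared config; `logger` is a plain string stand-in here."""
    logger: Optional[str] = None
    distributed_backend: Optional[str] = None
    gpus: int = 0


class ToyModule:
    """Stand-in module that caches setup shared by every Lightning function."""

    def __init__(self):
        self.prepare_calls = 0
        self._cache = {}

    def prepared_data(self):
        # Expensive setup runs once; Tuner and Trainer both reuse the result.
        if "data" not in self._cache:
            self.prepare_calls += 1
            self._cache["data"] = [0.0, 1.0, 2.0, 3.0]
        return self._cache["data"]


class Tuner:
    def __init__(self, config: LightningConfig):
        self.config = config

    def lr_finder(self, module, loss_fn, candidates=(1e-1, 1e-2, 1e-3)):
        data = module.prepared_data()
        # Simple grid search: lowest mean loss over the cached data wins.
        return min(candidates, key=lambda lr: sum(loss_fn(lr, x) for x in data) / len(data))


class Trainer:
    def __init__(self, config: LightningConfig):
        self.config = config

    def fit(self, module):
        # Returns the number of batches seen, just to make the sketch observable.
        return len(module.prepared_data())


lightning_config = LightningConfig(logger="tensorboard", distributed_backend="ddp", gpus=4)
model = ToyModule()
best_lr = Tuner(lightning_config).lr_finder(model, loss_fn=lambda lr, x: (lr - 0.01) ** 2 + x)
batches = Trainer(lightning_config).fit(model)
```

Freezing the dataclass keeps one component from silently mutating settings another relies on; the boilerplate concern shows up as the repeated `lightning_config` argument.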
@tullie thanks for the suggestion!
I'll write up some thoughts on this over the next few days while we think about the future direction of the project, usability, etc., to see how these refactors might impact that.
But I do agree that the trainer's functionality has grown quite a bit, and some features might become easier with some refactoring. My main concern is making sure the user doesn't have to think much when doing things, which flags avoid (though you could also argue that something like a tuner might be equally simple).
I feel like everything in DL is: ok, do steps 1, 2, 3 and 4, and now this simple feature can be used. Then users end up switching the order, not reading the docs, etc., and we end up in boilerplate land again haha.