Pytorch-lightning: Cross validation feature

Created on 14 Feb 2020 · 31 Comments · Source: PyTorchLightning/pytorch-lightning

🚀 Feature

Cross-validation is a crucial model validation technique for assessing how a model generalizes to new data.

Motivation

Research papers usually require cross-validation. From my point of view, this kind of feature would simplify the work of researchers.

Pitch

I want to pass a parameter to the Trainer object to specify that I want to train the model on K-folds.

If nobody else wants to make a PR, I can start working on it.

Labels: discussion, enhancement, good first issue, help wanted

Most helpful comment

what if we just integrate with sklearn cross validation? this can be the start of supporting sklearn interop

All 31 comments

I think that the cleaner way would be some abstraction above the dataloader, because cross-validation is just systematic train/test on a particular dataset... Anyway, a PR is welcome!
@BraveDistribution could you please describe in a bit more detail how you plan to implement it, or make a draft PR and we can talk about it there :robot:

@Borda, I don't have a plan for how to implement it yet because I haven't worked on it until now.

If I have any questions I will post them here; if not, I will make a PR directly.

what if we just integrate with sklearn cross validation? this can be the start of supporting sklearn interop

How would you propose that @williamFalcon?

In my "own" library I split the datasets into K folders by using my own script (you can use k-fold or stratified k-fold or any of the scikit methods).

dataset/k_0/train
dataset/k_0/test

dataset/k_1/train
dataset/k_1/test

Then I trained and evaluated K neural networks and finally grabbed all the results and saved the mean of accuracy, F1, and the other metrics.

That of course wastes HDD space equal to (K-1) * the size of the dataset. We shouldn't implement that approach.


I think we should add a new parameter to the Trainer, something like the cv parameter of GridSearchCV in scikit-learn:

cv : int, cross-validation generator or an iterable, optional
Determines the cross-validation splitting strategy. Possible inputs for cv are:
- None, to use the default 5-fold cross validation,
- integer, to specify the number of folds in a (Stratified)KFold,
- CV splitter,
- An iterable yielding (train, test) splits as arrays of indices.

For integer/None inputs, if the estimator is a classifier and y is either binary or multiclass, StratifiedKFold is used. In all other cases, KFold is used.
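
For reference, a minimal sketch of how those scikit-learn splitters behave, with toy NumPy data standing in for a real dataset (the index arrays they yield are what would need to be mapped onto PyTorch dataloaders):

import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

X = np.random.randn(100, 10)           # toy features
y = np.random.randint(0, 2, size=100)  # toy binary labels

# plain K-fold: yields (train_idx, test_idx) index arrays for each fold
for train_idx, test_idx in KFold(n_splits=5, shuffle=True, random_state=0).split(X):
    print(len(train_idx), len(test_idx))  # 80 20

# the stratified variant preserves the class balance of y in every fold
for train_idx, test_idx in StratifiedKFold(n_splits=5).split(X, y):
    pass  # these indices would be used to build per-fold dataloaders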

what if we just integrate with sklearn cross validation? this can be the start of supporting sklearn interop

@williamFalcon skorch has a nice implementation. https://github.com/skorch-dev/skorch/blob/f94466e272f6f325898359fecb9a7c004354af7f/skorch/dataset.py#L212

check use case in #1393

By passing data loaders directly to the Trainer, my CV loop looks like this:

import os

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# kfold, train_df, create_dataloader, YourPLModule, OUTPUT_PATH, args and
# early_stop_callback are assumed to be defined elsewhere in the script
for fold, (train_idx, valid_idx) in enumerate(kfold.split(train_df)):
    train_loader = create_dataloader(train_df.iloc[train_idx])
    valid_loader = create_dataloader(train_df.iloc[valid_idx])

    # Folder hack: give each fold its own logger version and checkpoint directory
    tb_logger = TensorBoardLogger(save_dir=OUTPUT_PATH, name=f'{args.model_name}', version=f'fold_{fold + 1}')
    os.makedirs(OUTPUT_PATH / f'{args.model_name}', exist_ok=True)
    checkpoint_callback = ModelCheckpoint(filepath=tb_logger.log_dir + "/{epoch:02d}-{val_metric:.4f}",
                                          monitor='val_metric', mode='max')

    model = YourPLModule(args)
    trainer = pl.Trainer(logger=tb_logger, early_stop_callback=early_stop_callback, checkpoint_callback=checkpoint_callback)
    trainer.fit(model, train_dataloader=train_loader, val_dataloaders=valid_loader)

Note that the folder hack is from https://github.com/PyTorchLightning/pytorch-lightning/issues/1207

It could be a nice feature, similar to the LR finder we have now...
@PyTorchLightning/core-contributors any other suggestions?
@Anjum48, I would say drafting a PR would be nice...

I wouldn't integrate this to fit or trainer init, but to a separate function internally calling fit

I wouldn't integrate this to fit or trainer init, but to a separate function internally calling fit

I agree, that's why I proposed doing it similar to the LR finder... lol

We should also somehow include the CV results in TensorBoard, to give scientists an easy way to check the quality of their models. I don't know much about TensorBoard, so I don't know whether that's possible.

Or, we should at least save the final results to a JSON / pickle file.
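
As a rough illustration of that second point, assuming each fold produced a plain dict of metrics (the names fold_metrics and results.json are made up for this sketch):

import json
import statistics

# hypothetical: one metrics dict per fold, e.g. collected from each trainer after fitting
fold_metrics = [
    {"acc": 0.91, "f1": 0.88},
    {"acc": 0.89, "f1": 0.86},
    {"acc": 0.93, "f1": 0.90},
]

# aggregate mean and std per metric across folds
summary = {
    key: {"mean": statistics.mean(m[key] for m in fold_metrics),
          "std": statistics.stdev(m[key] for m in fold_metrics)}
    for key in fold_metrics[0]
}

# dump the aggregated CV results so they can be inspected later
with open("results.json", "w") as f:
    json.dump(summary, f, indent=2)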

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

Are there any news on this?

@axkoenig how would you do it? Write a wrapper over the Trainer and perform the fold splitting followed by train-test?

I think we could have something like that in bolts, but it is very hard to generalize, since it always depends on how you want to split your data.

I think we could provide two options:

  1. Either users provide a single train_dataloader that we split into K new dataloaders with non-overlapping subsets of data, and we perform the cross-validation from them (a rough sketch is given below the list)
  2. Users provide K train_dataloaders and K test_dataloaders and we run cross validation on them (basically calling trainer.fit iteratively)
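
A minimal sketch of option 1, assuming the underlying dataset is map-style and that plain random splitting is acceptable (the helper name kfold_loaders is made up here):

import numpy as np
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Subset

def kfold_loaders(dataset, k=5, batch_size=32):
    # yield (train_loader, val_loader) pairs over k non-overlapping splits of the dataset
    for train_idx, val_idx in KFold(n_splits=k, shuffle=True).split(np.arange(len(dataset))):
        train_loader = DataLoader(Subset(dataset, train_idx.tolist()), batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(Subset(dataset, val_idx.tolist()), batch_size=batch_size)
        yield train_loader, val_loader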

@SkafteNicki I think this would be a good idea to start.

However, we might also want to have some stratified splitting and not just random splitting, which may become more difficult, since we would have to assume things (like structure, dtype, etc.) about these batches.

In general, we should also keep in mind that we may not want to split only into train and test but also into validation sets/data loaders.

@justusschock completely agree, I think that v1 of this feature should be very simple, just random splitting. My proposed option 2 would allow the user to provide their own stratified dataloaders.

In v2 we can begin to figure out how to do more advanced stuff / better integration. The main problem (in my view) is that we are working with dataloaders and not datasets, so to get dataset statistics (like class balance for stratified splitting) we need to explicitly run over the dataset and enforce a lot of structure in the batches (as you mention).
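
To make that last point concrete, a stratified split needs the labels up front, which for a generic map-style dataset means running over it once and assuming every sample is an (input, label) pair; that assumption is exactly the kind of structure that would have to be enforced:

from sklearn.model_selection import StratifiedKFold

# assumes dataset[i] returns an (x, y) pair with a scalar class label y
labels = [int(dataset[i][1]) for i in range(len(dataset))]

for train_idx, val_idx in StratifiedKFold(n_splits=5).split(labels, labels):
    ...  # build per-fold Subsets/DataLoaders as above, with class balance preserved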

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

Hi! Is there an update on this issue? Due to the ubiquity of the cross-validation strategy, it could be quite a significant addition to PL.

@astenuz we currently have a freeze on new features until the v1.0 release, since we want to focus on getting a very stable release. After v1.0 this is definitely something we would like to be a part of Lightning.

@SkafteNicki should this be a DataModule feature, as mentioned in #4287 ? Like the DataModule itself provides k dataloaders like you mentioned here.

cc @edenafek

let's pick this back up now

@ananyahjha93 the first question is how it should be integrated into lightning:
1) should trainer have a k_fold init argument?
2) should fit have a k_fold argument?
3) should trainer have a new method (cross_validate)
4) should this be a plugin?
5) should this be a completely new object wrapping around trainer (CV(Trainer(...)))?

I actually like the idea of having a separate class (CV) and some function in the data module for that. This way we would still have the trainer to train separate networks, without further bloating its state.

However, I'd prefer the interface to have the CV construct trainers internally from the passed args. So something like this:

from copy import deepcopy

from pytorch_lightning import Trainer


class CV:
    def __init__(self, *args, **kwargs):
        # keep the arguments so a fresh Trainer can be built for each fold
        self.trainer_args = args
        self.trainer_kwargs = kwargs

    def fit(self, model, data_module):
        for loaders in data_module.get_kfold():
            # train a fresh copy of the model on every fold
            fold_model = deepcopy(model)
            yield Trainer(*self.trainer_args, **self.trainer_kwargs).fit(fold_model, loaders)
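
For illustration, using that sketch might then look roughly like this (the Trainer arguments are just examples and get_kfold is still a hypothetical datamodule method):

cv = CV(max_epochs=10)
for fold_result in cv.fit(model, data_module):
    ...  # collect per-fold results/metrics here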

I am also in favor of a new separate class. Another thing is that the CV object will probably have some parameters of its own:
1) should the fitting be done in parallel (then we need to figure out how to map the individual fits to the devices)
2) should the cv be stratified (maybe not in v1 of this feature)
3) ...

I think that integration with optuna cross-validation would be a great match.

that's already supported today. I think there are tutorials about it as well, no?

But generally we want to make sure we build general tools that support any option, like Optuna.

I have not seen tutorials doing cross-validation with pytorch-lightning, nor pytorch-lightning + Optuna cross-validation.

I agree with you that the feature should be general.

@SkafteNicki I think for v1 the folds could run sequentially, and the data_module could have a method which creates the loaders (probably without stratification in v1, but it can be overwritten by the user). Also, it is not possible to stratify every kind of training :D
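
A hedged sketch of what such a data_module method could look like; the KFoldDataModule class and its get_kfold method are made up for illustration and do not exist in Lightning:

import numpy as np
from pytorch_lightning import LightningDataModule
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Subset

class KFoldDataModule(LightningDataModule):
    def __init__(self, dataset, k=5, batch_size=32):
        super().__init__()
        self.dataset, self.k, self.batch_size = dataset, k, batch_size

    def get_kfold(self):
        # yields one (train_loader, val_loader) pair per fold, to be consumed sequentially
        for train_idx, val_idx in KFold(n_splits=self.k, shuffle=True).split(np.arange(len(self.dataset))):
            yield (DataLoader(Subset(self.dataset, train_idx.tolist()), batch_size=self.batch_size, shuffle=True),
                   DataLoader(Subset(self.dataset, val_idx.tolist()), batch_size=self.batch_size))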

Any specific plans on this? I have been trying to implement something like https://github.com/PyTorchLightning/pytorch-lightning/issues/839#issuecomment-714273956, but I am running into some rough edges like managing loggers and checkpoints across folds. There are also open questions about how to deal with the test parts.

I'd be happy to work on a PR given some guidance on how you'd like this implemented!
