It seems that when using a DataModule to separate training logic from data loading, of the five methods that should be called, namely prepare_data(), setup(), train_dataloader(), val_dataloader() and test_dataloader(), only the last three are actually used. This is problematic, since the datasets used by the dataloaders should be assigned in setup().
Steps to reproduce the behavior:
Run this:
import torch
from pytorch_lightning import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer import Trainer
from torch.nn import L1Loss, Linear
from torch.optim import SGD
from torch.utils.data import DataLoader


class MyDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()

    def prepare_data(self):
        print('in prepare_data, '
              'this should be called before train_dataloader() but is not.')

    def setup(self, stage):
        print('in setup, '
              'this should be called before train_dataloader() but is not.')
        self.train_dataset = 'whatever'

    def train_dataloader(self):
        print('in train_dataloader')
        return DataLoader(self.train_dataset)


class MyLightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = Linear(1, 1)
        self.loss_function = L1Loss()

    def forward(self, x):
        return self.layer(x)

    def configure_optimizers(self):
        return SGD(self.parameters(), lr=0.01)

    def training_step(self, batch, batch_idx):
        print("you won't even get here")
        raise NotImplementedError


data_module = MyDataModule()
model = MyLightningModule()
trainer = Trainer(gpus=1)
trainer.fit(model, data_module)
This gives AttributeError: 'MyDataModule' object has no attribute 'train_dataset'.
When entering train_dataloader(), prepare_data() and setup() should already have been executed, and thus the train_dataset attribute should exist.
IMHO, it comes from here
You're not specifying the datamodule kwarg in trainer.fit() - your last line should look like this: trainer.fit(model, datamodule=data_module)
In this first iteration of LightningDataModule, you have to call setup and prepare_data manually for the datamodule instance. We have it set up this way so that if you don't want to use Lightning, you can still use your datamodule's loaders with pure PyTorch. I thought of having them called implicitly in the PR, but ended up landing on this for now. I'm not sure if users would _always_ want these to run implicitly.
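For example, here is a rough sketch of driving a datamodule's loaders from a plain PyTorch loop. Everything in it (the PlainDataModule class, its toy tensors, and the tiny model) is made up for illustration and is not part of the snippet above:

import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, TensorDataset


class PlainDataModule(LightningDataModule):
    def prepare_data(self):
        # nothing to download in this toy example
        pass

    def setup(self, stage=None):
        # build the dataset here, exactly as you would for the Lightning flow
        x = torch.randn(64, 1)
        y = 3 * x
        self.train_dataset = TensorDataset(x, y)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=8)


# Call the lifecycle hooks yourself, then use the loader like any other DataLoader.
dm = PlainDataModule()
dm.prepare_data()
dm.setup()

model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.L1Loss()

for x, y in dm.train_dataloader():
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()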
TL;DR: you can update your code to look like this:
# Init a datamodule
dm = MyDataModule()
# Manually call prepare_data and setup. You could put this at end of __init__ if you want
dm.prepare_data()
dm.setup()
model = MyLightningModule()
trainer = Trainer(gpus=1)
trainer.fit(model, datamodule=dm)
That being said, we're open to any ideas on making this more intuitive, so feel free to throw out some alternatives. 😄
That is not true in 0.9.0rc2: a data module passed as the second positional argument is taken care of here.
I don't have a global enough view to know what other users might want, so if it is a feature, I'm fine with it.
I just saw that the manual call is mentioned in the docs, my bad for not looking far enough.
Anyway, thank you for the clear answer ^^
@remisphere I totally didn't notice! You were completely right on the dm arg. Things move fast haha.
Reopening actually, as I think your intended use is more user-friendly.