Pytorch-lightning: Any way to make Lightning work with petastorm custom DataLoaders?

Created on 24 Apr 2020 · 7 Comments · Source: PyTorchLightning/pytorch-lightning

Is it possible to use petastorm (https://github.com/uber/petastorm) PyTorch data loaders with PyTorch Lightning?

The issue is that petastorm's DataLoaders need to be re-created for each epoch.
Sample code looks like this:

for epoch in range(1, loop_epochs + 1):
        train_loader = DataLoader(...)
        train(model, device, train_loader, args.log_interval, optimizer, epoch)
        test_loader = DataLoader(...)
        test(model, device, test_loader)

The dataloader keeps its state, so refactoring the snippet as below breaks for epochs > 1:

train_loader = DataLoader(...)
test_loader = DataLoader(...)
for epoch in range(1, loop_epochs + 1):
        train(model, device, train_loader, args.log_interval, optimizer, epoch)
        test(model, device, test_loader)

Thanks for your help, guys.

question

All 7 comments

Hi! Thanks for your contribution, great first issue!

I am not familiar with petastorm, but if the issue is just about resetting the dataloaders every epoch, there is a Trainer flag for that:
https://pytorch-lightning.readthedocs.io/en/stable/pytorch_lightning.trainer.html#reload-dataloaders-every-epoch
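
Roughly something like this should work (untested sketch; I'm assuming petastorm's make_reader / petastorm.pytorch.DataLoader API here, and the dataset URL, batch size, and class name are placeholders):

import pytorch_lightning as pl
from petastorm import make_reader
from petastorm.pytorch import DataLoader

class PetastormModule(pl.LightningModule):
    # ... training_step, configure_optimizers, etc. omitted ...

    def train_dataloader(self):
        # With reload_dataloaders_every_epoch=True, Lightning calls this again
        # before every epoch, so each epoch gets a fresh petastorm loader.
        return DataLoader(make_reader('file:///path/to/train', num_epochs=1),
                          batch_size=64)

trainer = pl.Trainer(reload_dataloaders_every_epoch=True)
trainer.fit(PetastormModule())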

Thanks @awaelchli, I'll try it out.

Thanks @awaelchli, it worked as expected.
Do you know if it's possible to call a function on the dataloaders at the end of an epoch?
Petastorm dataloaders work as context managers, and I'd like to call __exit__() on them at the end of an epoch.
Thanks again.

Yes it is, for example with a model hook:
https://pytorch-lightning.readthedocs.io/en/latest/hooks.html
Just override on_epoch_end in your LightningModule.
You can access the dataloaders via self.trainer.train_dataloader and self.trainer.val_dataloaders (a list).

So using reload_dataloaders_every_epoch=True was giving me the error:

Fatal Python error: This thread state must be current when releasing

I instead used on_epoch_start to reset the dataloaders:

        def on_epoch_start(self):
            # Re-create the petastorm dataloaders at the start of every epoch
            self.trainer.train_dataloader = self.train_dataloader()
            self.trainer.val_dataloaders = [self.val_dataloader()]

        def on_epoch_end(self):
            # Close the petastorm dataloaders (they are context managers)
            self.trainer.train_dataloader.__exit__(None, None, None)
            for dl in self.trainer.val_dataloaders:
                dl.__exit__(None, None, None)
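
The dataloader methods themselves just build a new petastorm loader each time (rough sketch, assuming make_reader / petastorm.pytorch.DataLoader; the paths and batch size are placeholders):

        def train_dataloader(self):
            # on_epoch_start above re-calls this, so every epoch gets a new loader
            return DataLoader(make_reader('file:///path/to/train', num_epochs=1),
                              batch_size=64)

        def val_dataloader(self):
            return DataLoader(make_reader('file:///path/to/val', num_epochs=1),
                              batch_size=64)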

Now everything works properly. Thanks @awaelchli for the help.

great!
