Is it possible to use petastorm (https://github.com/uber/petastorm) pytorch data loaders with pytorch lightning?
The issue is that petastorm's DataLoaders need to be re-initialized for each epoch.
Sample code looks like this:
for epoch in range(1, loop_epochs + 1):
    train_loader = DataLoader(...)
    train(model, device, train_loader, args.log_interval, optimizer, epoch)
    test_loader = DataLoader(...)
    test(model, device, test_loader)
The dataloader keeps its state, so refactoring the snippet as below breaks for epochs > 1:
train_loader = DataLoader(...)
test_loader = DataLoader(...)
for epoch in range(1, loop_epochs + 1):
    train(model, device, train_loader, args.log_interval, optimizer, epoch)
    test(model, device, test_loader)
Thanks for your help, guys.
Hi! Thanks for your contribution, great first issue!
I am not familiar with petastorm, but if the issue is just about resetting the dataloaders every epoch, there is a Trainer flag for that:
https://pytorch-lightning.readthedocs.io/en/stable/pytorch_lightning.trainer.html#reload-dataloaders-every-epoch
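For example, a minimal sketch (the max_epochs value and the model are placeholders; the flag itself is the one from the docs linked above):

from pytorch_lightning import Trainer

# Ask the Trainer to re-create the dataloaders at every epoch
# instead of reusing the ones built at the start of fit().
trainer = Trainer(reload_dataloaders_every_epoch=True, max_epochs=10)
trainer.fit(model)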
Thanks @awaelchli I'll try it out.
Thanks @awaelchli, it worked as expected.
Do you know if it's possible to call a function on the dataloaders at the end of an epoch?
Petastorm dataloaders work as context managers, and I'd like to call __exit__() on them at the end of an epoch.
Thanks again.
Yes it is, for example with a model hook:
https://pytorch-lightning.readthedocs.io/en/latest/hooks.html
Just override on_epoch_end in your LightningModule.
You can access the dataloaders via self.trainer.train_dataloader and self.trainer.val_dataloaders (a list).
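For instance, a minimal sketch of that cleanup hook (assuming the dataloaders are petastorm loaders that support __exit__):

def on_epoch_end(self):
    # Close the petastorm readers once the epoch has finished.
    self.trainer.train_dataloader.__exit__(None, None, None)
    for dl in self.trainer.val_dataloaders:
        dl.__exit__(None, None, None)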
So using reload_dataloaders_every_epoch=True was giving me the error:
Fatal Python error: This thread state must be current when releasing
I instead used on_epoch_start to reset the dataloaders.
def on_epoch_start(self):
    # Re-create the petastorm dataloaders at the start of every epoch.
    self.trainer.train_dataloader = self.train_dataloader()
    self.trainer.val_dataloaders = [self.val_dataloader()]

def on_epoch_end(self):
    # Close the petastorm dataloaders (they are context managers).
    self.trainer.train_dataloader.__exit__(None, None, None)
    for dl in self.trainer.val_dataloaders:
        dl.__exit__(None, None, None)
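For context, the dataloader-producing methods these hooks rely on might look roughly like this (a hedged sketch: the dataset paths and batch size are just placeholders, and make_reader / petastorm.pytorch.DataLoader are the usual petastorm entry points):

from petastorm import make_reader
from petastorm.pytorch import DataLoader

def train_dataloader(self):
    # A fresh reader is built on every call, which is what lets
    # on_epoch_start swap in a new loader each epoch.
    return DataLoader(make_reader('file:///path/to/train_set', num_epochs=1), batch_size=64)

def val_dataloader(self):
    return DataLoader(make_reader('file:///path/to/val_set', num_epochs=1), batch_size=64)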
Now everything works properly. Thanks @awaelchli for the help.
great!