Pytorch-lightning: Add IterableDataset support

Created on 7 Oct 2019 · 19 Comments · Source: PyTorchLightning/pytorch-lightning

Looks like currently there is no way to use an IterableDataset instance for training. Trying to do so results in a crash with this exception:

Traceback (most recent call last):
  File "main.py", line 12, in <module>
    trainer.fit(model)
  File "/home/akonstantinov/.miniconda3/envs/cpn/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 677, in fit
    self.__single_gpu_train(model)
  File "/home/akonstantinov/.miniconda3/envs/cpn/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 725, in __single_gpu_train
    self.__run_pretrain_routine(model)
  File "/home/akonstantinov/.miniconda3/envs/cpn/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 882, in __run_pretrain_routine
    self.__layout_bookeeping()
  File "/home/akonstantinov/.miniconda3/envs/cpn/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 436, in __layout_bookeeping
    self.nb_training_batches = len(self.train_dataloader)
  File "/home/akonstantinov/.miniconda3/envs/cpn/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 297, in __len__
    return len(self._index_sampler)  # with iterable-style dataset, this will error
  File "/home/akonstantinov/.miniconda3/envs/cpn/lib/python3.7/site-packages/torch/utils/data/sampler.py", line 212, in __len__
    return (len(self.sampler) + self.batch_size - 1) // self.batch_size
  File "/home/akonstantinov/.miniconda3/envs/cpn/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 57, in __len__
    raise TypeError('Cannot determine the DataLoader length of a IterableDataset')
TypeError: Cannot determine the DataLoader length of a IterableDataset
Labels: enhancement, help wanted


All 19 comments

Got it.

Since it's impossible to know when to run validation, checkpoint, or stop training, the workaround is to return a really high number from the __len__ of your dataloader.
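
A minimal sketch of that workaround, assuming you subclass DataLoader (the class name FakeLengthDataLoader and the use of sys.maxsize are made up, not part of Lightning):

    import sys
    from torch.utils.data import DataLoader

    class FakeLengthDataLoader(DataLoader):
        # Hypothetical workaround: report a very large length so that
        # len(train_dataloader) in the Trainer bookkeeping does not raise.
        def __len__(self):
            # The epoch never ends on its own, so you rely on the
            # validation interval / early stopping to decide when to stop.
            return sys.maxsize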

To support this, we should modify the training loop so it runs the validation check, etc., every k batches. We might also need to disable the tqdm total because we won't know the length.

The use case is for streaming data or database-type reads.

There is also a somewhat simpler case, where the length is actually known but random access by index is not available. That is true in my case: my dataset generates samples on the fly, but always a fixed number per epoch. Say, every epoch 10k samples are generated and fed into the model in batches of 100 samples.
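
For example, a minimal sketch of such a dataset (the class name and sample shapes are made up): the samples are produced on the fly, so there is no random access, but exactly 10k samples come out per epoch, which the DataLoader turns into 100-sample batches.

    import torch
    from torch.utils.data import DataLoader, IterableDataset

    class GeneratedDataset(IterableDataset):
        # Samples are generated on the fly (no __getitem__), but the
        # per-epoch length is fixed and known in advance.
        def __init__(self, samples_per_epoch=10_000):
            self.samples_per_epoch = samples_per_epoch

        def __iter__(self):
            for _ in range(self.samples_per_epoch):
                yield torch.randn(32), torch.randint(0, 2, (1,))

    loader = DataLoader(GeneratedDataset(), batch_size=100)  # 100 batches per epoch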

So, actually, all of this can be solved by adding a way to say how many batches you want per epoch. Then everything just works out.

Trainer(max_epoch_batches=10000)

@neggert any thoughts?

In the past, I've handled this by storing a num_batches attribute in a custom batch sampler (which I needed to use for other reasons). Then we just do this:

    def __len__(self):
        return self.num_batches

    def _get_batch(self):
        ...

    def __iter__(self):
        # Yield exactly num_batches batches per epoch.
        return (self._get_batch() for _ in range(self.num_batches))

This is probably not a good general solution, as it would have been a lot of work if I hadn't been planning on using a custom batch sampler anyway.

For a general solution, I think a max_epoch_batches arg is a good idea.

We do need to be a little bit careful, as there are some pitfalls around using an IterableDataset with multiple workers or nodes. It would be good to warn users about these like we do with DistributedSampler.
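
To make that pitfall concrete, here is a sketch of the usual get_worker_info() sharding pattern (ShardedStream and its numbers are made up); without it, every DataLoader worker iterates the same stream and silently duplicates data:

    import torch
    from torch.utils.data import IterableDataset, get_worker_info

    class ShardedStream(IterableDataset):
        def __init__(self, num_samples=10_000):
            self.num_samples = num_samples

        def __iter__(self):
            info = get_worker_info()
            worker_id = info.id if info is not None else 0
            num_workers = info.num_workers if info is not None else 1
            # Each worker takes a strided slice so no sample is emitted twice.
            for idx in range(worker_id, self.num_samples, num_workers):
                yield torch.tensor(idx, dtype=torch.float32)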

So, the resolution here is to add an argument:

Trainer(max_epoch_batches=10000), which would run the validation check at that interval (and override all the other settings for this).

(open to a better name for this)

@Borda @neggert any of u guys want to take a stab at this?

Added in #405

To use:

Trainer(val_check_interval=100)

(runs the validation check every 100 training batches)

@falceeffect please verify it works as requested. Otherwise we can reopen

Doesn't solve the problem

Same here. Still gives an error.

@calclavia which problem? compatibility? (@MikeScarp)

@williamFalcon The original issue reported here is not fixed by https://github.com/williamFalcon/pytorch-lightning/pull/405. I am still unable to train with an instance of IterableDataset.

Right. Putting in val_check_interval works, but it doesn't seem to stop PyTorch Lightning from asking for the len of the dataset, which leads to a crash since IterableDataset doesn't support the len call.

@calclavia mind submitting a PR? where is the len being asked? i thought we specifically handled that case

I tried defining __len__(self) in my dataset class (which inherits from torch.utils.data.IterableDataset), but it still didn't work.
For me, the actual error occurs on line 297 of torch/utils/data/dataloader.py, when it tries to call len(self._index_sampler).
It would be cool if pytorch-lightning supported IterableDatasets by calling the right torch functions.

It appears there is a typo in the latest pip-installable version, 0.5.3.2: https://github.com/williamFalcon/pytorch-lightning/blob/0.5.3.2/pytorch_lightning/trainer/data_loading_mixin.py#L27.

It should be isinstance(self.get_train_dataloader().dataset, IterableDataset) instead of isinstance(self.get_train_dataloader(), IterableDataset)

This is later fixed in https://github.com/williamFalcon/pytorch-lightning/pull/549 but it has not been released.

@williamFalcon will you be able to make a new release soon since there was no release in December? Thanks!

Actually, even the latest master branch still has this problem.

Traceback (most recent call last):
  File "scripts/msmacro.py", line 119, in <module>
    main()
  File "scripts/msmacro.py", line 115, in main
    trainer.fit(model)
  File "/home/zhaohao/.anaconda3/envs/pytorch/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 417, in fit
    self.run_pretrain_routine(model)
  File "/home/zhaohao/.anaconda3/envs/pytorch/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 481, in run_pretrain_routine
    self.get_dataloaders(ref_model)
  File "/home/zhaohao/.anaconda3/envs/pytorch/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py", line 199, in get_dataloaders
    self.init_train_dataloader(model)
  File "/home/zhaohao/.anaconda3/envs/pytorch/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py", line 78, in init_train_dataloader
    self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
OverflowError: cannot convert float infinity to integer
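
For context, the failing line multiplies an infinite batch count by the fractional validation interval; a minimal reproduction with assumed values:

    num_training_batches = float("inf")  # recorded when the dataloader length is unknown
    val_check_interval = 0.25            # assumed fractional (per-epoch) value
    int(num_training_batches * val_check_interval)
    # OverflowError: cannot convert float infinity to integer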

@matthew-z Could you please reopen this issue or make another one?

I don't have the privilege to re-open a closed issue, so I will open a new one.
