Ignite: Running trainer for more than one epoch with torch.utils.data.IterableDataset

Created on 21 Apr 2020 · 5 Comments · Source: pytorch/ignite

๐Ÿ› Bug description

Hello,
I don't know if I am missing something, but I am trying to run my trainer for more than one epoch, where each epoch has 5 iterations. However, no new iterator is created over the dataloader after the first epoch, and I get the warning below.

tensor([1])
Epoch [1/3]: [1/5]  20%|██         [00:00<00:00]
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ignite/contrib/handlers/base_logger.py:124: UserWarning: Provided metric name 'loss' is missing in engine's state metrics: []
  "in engine's state metrics: {}".format(name, list(engine.state.metrics.keys()))
tensor([4])
Epoch [1/3]: [1/5]  20%|██         [00:00<00:00]
tensor([9])
Epoch [1/3]: [2/5]  40%|████       [00:00<00:00]
tensor([16])
Epoch [1/3]: [3/5]  60%|██████     [00:00<00:00]
tensor([25])
Epoch [1/3]: [5/5] 100%|██████████ [00:00<00:00]
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ignite/engine/engine.py:465: UserWarning: Data iterator can not provide data anymore but required total number of iterations to run is not reached. Current iteration: 5 vs Total iterations to run : 15
  self.state.iteration, self.state.epoch_length * self.state.max_epochs

If I do not use ignite and write the epoch loop by hand instead, it works fine. Here is a code sample for my use case.

import torch
from ignite.engine import Engine, Events
from ignite.contrib.handlers import ProgressBar
import pdb

class DatasetUtils:
    def __init__(self, datapoints):
        self.datapoints = datapoints

    def load_data(self):
        for d in self.datapoints:
            yield d

    def create_examples(self, data):
        for datapoint in data:
            yield datapoint**2

    def __len__(self):
        return len(self.datapoints)


class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, helper):
        self.helper = helper

    def __iter__(self):
        data_iter = self.helper.load_data()
        for example in self.helper.create_examples(data_iter):
            yield example

    def __len__(self):
        return len(self.helper)

datapoints = [1, 2, 3, 4, 5]
helper = DatasetUtils(datapoints)

ds = MyIterableDataset(helper)

data_loader = torch.utils.data.DataLoader(ds, num_workers=1, batch_size=1)


def update(engine, batch):
    print(batch)

# Using ignite 
trainer = Engine(update)
pbar = ProgressBar(persist=True)
pbar.attach(trainer, metric_names=["loss"])
trainer.run(map(lambda x: x, data_loader), epoch_length=5//1, max_epochs=3)


# Not using ignite
for epoch in range(3):
    data_iter = map(lambda x: x, data_loader)
    while True:
        try:
            batch = next(data_iter)
            print(batch)
        except StopIteration:
            break

Thank you

Environment

All 5 comments

@bhedayat Thank you for the report! I can reproduce it. Let me check.

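It seems your two versions, with and without ignite, are not equivalent. To match what is passed to ignite, the version without ignite should build the map object once, outside the epoch loop: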

data = map(lambda x: x, data_loader)
for epoch in range(3):
    data_iter = iter(data)
    while True:
        try:
            batch = next(data_iter)
            print(batch)
        except StopIteration:
            break

The issue occurs in both cases.
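The reason both versions fail can be shown in plain Python: a `map` object is a one-shot iterator, so calling `iter()` on it returns the same, already exhausted object. The sketch below is only an illustration. `Squares` is a hypothetical stand-in for the `DataLoader` and is not part of the original code.

```python
class Squares:
    """Hypothetical re-iterable stand-in for the DataLoader:
    every iter() call starts a fresh pass over the data."""

    def __init__(self, datapoints):
        self.datapoints = datapoints

    def __iter__(self):
        for d in self.datapoints:
            yield d ** 2


loader = Squares([1, 2, 3, 4, 5])

# Wrapping the loader in map() produces a one-shot iterator.
data = map(lambda x: x, loader)

# iter() on an iterator returns the iterator itself, not a fresh one.
same_object = iter(data) is data

first_pass = list(data)   # consumes all five items
second_pass = list(data)  # nothing left: the map object is exhausted

# Building a new map object restarts iteration over the loader.
fresh_pass = list(map(lambda x: x, loader))
```

Here `second_pass` is empty, which corresponds to the "Data iterator can not provide data anymore" warning after the first epoch.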

@bhedayat you can restart the iterator in a handler that fires at the end of each epoch (here, every 5 iterations), as in this adaptation of your code snippet:

# Using ignite 
trainer = Engine(update)
pbar = ProgressBar(persist=True)
pbar.attach(trainer, metric_names=["loss"])

@trainer.on(Events.ITERATION_COMPLETED(every=5))
def restart_dataloader():
    print(trainer.state.iteration, "restart_dataloader")
    trainer.state.dataloader = map(lambda x: x, data_loader)

trainer.run(map(lambda x: x, data_loader), epoch_length=5, max_epochs=3)

@vfdev-5 You are so fast :)

It works! Thank you for your help
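The restart pattern from the fix above can also be sketched without ignite, to make the mechanism explicit. This is a minimal illustration and not ignite's implementation. `run_epochs` and `make_iterator` are hypothetical names, and a plain list replaces the `DataLoader`.

```python
def run_epochs(make_iterator, epoch_length, max_epochs):
    """Consume epoch_length items per epoch, rebuilding the iterator
    after every epoch (the role restart_dataloader plays in the thread)."""
    seen = []
    data_iter = make_iterator()
    for _ in range(max_epochs):
        for _ in range(epoch_length):
            seen.append(next(data_iter))
        # Restart: a one-shot iterator cannot be rewound, so build a new one.
        data_iter = make_iterator()
    return seen


datapoints = [1, 2, 3, 4, 5]

# The factory returns a fresh map object on every call.
batches = run_epochs(
    lambda: map(lambda x: x ** 2, datapoints),
    epoch_length=5,
    max_epochs=3,
)
```

Passing a factory rather than an iterator is the design choice here: the loop never has to assume that the data source can be iterated twice.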
