Ignite: DistributedProxySampler RuntimeError when indices are padded

Created on 9 Jul 2020 · 6 Comments · Source: pytorch/ignite

🐛 Bug description

The RuntimeError raised in DistributedProxySampler on line 241 shouldn't be there, since the indices are now padded with the full sample (a change made because of this comment).

Environment

  • PyTorch Version (e.g., 1.4):
  • Ignite Version (e.g., 0.3.0):
  • OS (e.g., Linux):
  • How you installed Ignite (conda, pip, source):
  • Python version:
  • Any other relevant information:

All 6 comments

@ryanwongsa thanks for reporting. Can you provide a reproducible example of how this error occurred? Thanks

Taking the example from the unit test and setting num_replicas to 8 produces the error:

from ignite.distributed.auto import DistributedProxySampler
import torch
from torch.utils.data import WeightedRandomSampler

weights = torch.ones(100)
weights[:50] += 1
num_samples = 100
sampler = WeightedRandomSampler(weights, num_samples)

num_replicas = 8
dist_samplers = [DistributedProxySampler(sampler, num_replicas=num_replicas, rank=i) for i in range(num_replicas)]

torch.manual_seed(0)
true_indices = list(sampler)

indices_per_rank = []
for s in dist_samplers:
    s.set_epoch(0)
    indices_per_rank += list(s)

assert set(indices_per_rank) == set(true_indices)

The error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-d02cd2dd1018> in <module>
     17 for s in dist_samplers:
     18     s.set_epoch(0)
---> 19     indices_per_rank += list(s)

/opt/conda/lib/python3.7/site-packages/ignite/distributed/auto.py in __iter__(self)
    240 
    241         if len(indices) != self.total_size:
--> 242             raise RuntimeError("{} vs {}".format(len(indices), self.total_size))
    243 
    244         # subsample

RuntimeError: 200 vs 104

The assert will also fail after the RuntimeError is fixed, but that is because of the padding.
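The two numbers in the error message can be reproduced by hand. Below is a minimal sketch of the arithmetic, assuming the sampler sizes itself the way torch's DistributedSampler does (each rank gets ceil(len / num_replicas) indices) and pads by appending whole passes of the wrapped sampler. The variable names are illustrative only and do not come from the Ignite source.

```python
import math

num_samples = 100   # length of the wrapped WeightedRandomSampler
num_replicas = 8    # number of distributed processes

# DistributedSampler-style sizes: every rank receives the same number of indices.
per_rank = math.ceil(num_samples / num_replicas)
total_size = per_rank * num_replicas

# Assumed padding strategy: append whole passes until the list is long enough.
length = 0
passes = 0
while length < total_size:
    length += num_samples
    passes += 1

print(per_rank, total_size, passes, length)  # 13 104 2 200
```

Because a whole pass is appended each time, the padded length overshoots 104 and lands on 200, which is exactly the "200 vs 104" in the traceback.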

@ryanwongsa thanks for the repro snippet! Well, there are two points here:
1) len(indices) can indeed be larger than self.total_size, and this can be fixed by adding, for example,

        if len(indices) > self.total_size:
            indices = indices[:self.total_size]

        if len(indices) != self.total_size:
            raise RuntimeError("{} vs {}".format(len(indices), self.total_size))

2) A test on 100 samples seems small and depends on the seed. With num_samples=200, the test passes:

import torch
from torch.utils.data import WeightedRandomSampler

weights = torch.ones(100)
weights[:50] += 1
num_samples = 200
sampler = WeightedRandomSampler(weights, num_samples)

num_replicas = 8
dist_samplers = [FixedDistributedProxySampler(sampler, num_replicas=num_replicas, rank=i) for i in range(num_replicas)]


for seed in range(100):
    torch.manual_seed(seed)
    true_indices = list(sampler)

    indices_per_rank = []
    for s in dist_samplers:
        s.set_epoch(seed)
        indices_per_rank += list(s)


    set_indices_per_rank = set(indices_per_rank)
    set_true_indices = set(true_indices)
    assert set_indices_per_rank == set_true_indices, "{} | {}".format(set_true_indices - set_indices_per_rank, set_indices_per_rank - set_true_indices)

(Here FixedDistributedProxySampler is DistributedProxySampler with the 2 lines from 1) added.)

What do you think?

PS: It is true that for num_samples=100 the condition

assert set(indices_per_rank) == set(true_indices)

may not be satisfied, as indices_per_rank contains more generated samples than true_indices. A fairer comparison would use 104 true_indices instead of 100, drawn in two passes over the sampler as is done inside DistributedProxySampler.
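To illustrate the fairer comparison, here is a rough sketch that uses only Python's random module as a stand-in for WeightedRandomSampler, so it does not reproduce torch's actual random stream. The helpers one_pass and padded_indices are hypothetical names. The reference list is built the same way as the per-rank lists: seeded, padded by repeated passes, then cut to total_size.

```python
import random

def one_pass(rng, n=100):
    # Stand-in for one pass over the weighted sampler: n draws with
    # replacement, the first 50 indices being twice as likely as the rest.
    weights = [2] * 50 + [1] * 50
    return rng.choices(range(100), weights=weights, k=n)

def padded_indices(seed, total_size):
    # Re-seed, draw whole passes until long enough, then truncate.
    rng = random.Random(seed)
    indices = []
    while len(indices) < total_size:
        indices += one_pass(rng)
    return indices[:total_size]

num_replicas, total_size = 8, 104
reference = padded_indices(seed=0, total_size=total_size)

indices_per_rank = []
for rank in range(num_replicas):
    # Each rank rebuilds the same seeded list and takes a strided slice of it.
    indices = padded_indices(seed=0, total_size=total_size)
    indices_per_rank += indices[rank:total_size:num_replicas]

# The ranks together cover exactly the 104 reference indices.
assert sorted(indices_per_rank) == sorted(reference)
```

With a reference of 104 indices instead of 100, the comparison no longer depends on which extra samples the padding happened to draw.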

@vfdev-5 Thanks for the response

  1. I think adding the extra code to check the condition len(indices) != self.total_size seems a bit redundant, as the preceding lines in the sampler currently are:
while len(indices) < self.total_size:
    indices += list(self.sampler)

So len(indices) == self.total_size is guaranteed to hold if you add:

if len(indices) > self.total_size:
    indices = indices[:self.total_size]

I could be wrong but it doesn't seem useful to have the RuntimeError check in place and adding the indices = indices[:self.total_size] component.
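The redundancy argument can be checked in isolation. The sketch below is not the Ignite implementation: pad_and_truncate is a hypothetical helper, and the guard against an empty sampler is an extra assumption added so the loop cannot spin forever.

```python
def pad_and_truncate(draw_pass, total_size):
    """Append whole passes until long enough, then cut to total_size."""
    indices = []
    while len(indices) < total_size:
        chunk = list(draw_pass())
        if not chunk:
            # Assumption for this sketch: an empty sampler is an error.
            raise ValueError("sampler produced no indices")
        indices += chunk
    # The loop exits only with len(indices) >= total_size, so after slicing
    # the length is exactly total_size.
    return indices[:total_size]

# The length invariant holds whatever the pass length and target size are.
for n in (1, 7, 100, 104, 250):
    for total_size in (8, 104, 1000):
        out = pad_and_truncate(lambda: range(n), total_size)
        assert len(out) == total_size
```

Since the loop guarantees "at least total_size" and the slice guarantees "at most total_size", a later equality check can never fire.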

  2. Updating the test seems good, although I think the original test was good enough. To properly test in a unit test that 1. was fixed, I think having a test like:
import torch
from torch.utils.data import WeightedRandomSampler

weights = torch.ones(100)
weights[:50] += 1
num_samples = 100
sampler = WeightedRandomSampler(weights, num_samples)

num_replicas = 8
dist_samplers = [FixedDistributedProxySampler(sampler, num_replicas=num_replicas, rank=i) for i in range(num_replicas)]

torch.manual_seed(0)
true_indices = list(sampler)[0:num_samples//num_replicas*num_replicas]

indices_per_rank = []
for s in dist_samplers:
    s.set_epoch(0)
    indices_per_rank += list(s)[0:num_samples//num_replicas]

assert set(indices_per_rank) == set(true_indices)

I am not sure if it is the correct way to test it, but it solves the test issue when num_samples is not divisible by num_replicas.

I could be wrong but it doesn't seem useful to have the RuntimeError check in place and adding the indices = indices[:self.total_size] component.

@ryanwongsa yes, you are right: this check is redundant, and we can remove it if we truncate the indices as proposed.

I am not sure if it is the correct way to test it, but it solves the test issue when num_samples is not divisible by num_replicas.

Yes, this could work like that. However, what matters is which samples the model sees during training. What happens here is that on a single device the model will see 100 samples during a single epoch, while the distributed model will see 104 samples.
All samples (100 and 104) are generated with the desired sampler and should normally follow the intended distribution. But in the DDP case, the model will see 4 more samples. Probably, that is OK. Anyway, the same thing can happen for large datasets: the DDP model will see at most (world_size - 1) more samples in a single epoch, and IMO that is not a problem...
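The (world_size - 1) bound follows from rounding the sample count up to the next multiple of world_size. A small sketch of that arithmetic, with extra_samples as a hypothetical helper name:

```python
import math

def extra_samples(num_samples, world_size):
    # Round up to the next multiple of world_size, as the padding does.
    total_size = math.ceil(num_samples / world_size) * world_size
    return total_size - num_samples

# The case from this thread: 100 samples on 8 processes gives 4 extra samples.
assert extra_samples(100, 8) == 4

# Over many dataset sizes the surplus never exceeds world_size - 1.
worst = max(extra_samples(n, 8) for n in range(1, 1000))
assert worst == 7
```

The surplus is zero whenever the sample count is already a multiple of world_size.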

If you would like to help us land a fix for this issue a bit faster, feel free to send a PR :) Otherwise, we'll fix it in the coming days...

@vfdev-5 I created PR #1192 with the changes you described above. The test has also been updated to reflect the newer test you described earlier. I am not sure whether the PR is how you wanted it, so it would be good to get feedback. Thanks
