The RuntimeError that occurs in the DistributedProxySampler on line 241 shouldn't be there since the indices are padded with the full sample which was updated because of this comment.
@ryanwongsa thanks for reporting. Can you provide a reproducible example of how this error occurred? Thanks
Taking the example from the unit test and setting num_replicas to 8 produces the error:
from ignite.distributed.auto import DistributedProxySampler
import torch
from torch.utils.data import WeightedRandomSampler
weights = torch.ones(100)
weights[:50] += 1
num_samples = 100
sampler = WeightedRandomSampler(weights, num_samples)
num_replicas = 8
dist_samplers = [DistributedProxySampler(sampler, num_replicas=num_replicas, rank=i) for i in range(num_replicas)]
torch.manual_seed(0)
true_indices = list(sampler)
indices_per_rank = []
for s in dist_samplers:
    s.set_epoch(0)
    indices_per_rank += list(s)
assert set(indices_per_rank) == set(true_indices)
The error:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-d02cd2dd1018> in <module>
     17 for s in dist_samplers:
     18     s.set_epoch(0)
---> 19     indices_per_rank += list(s)

/opt/conda/lib/python3.7/site-packages/ignite/distributed/auto.py in __iter__(self)
    240
    241         if len(indices) != self.total_size:
--> 242             raise RuntimeError("{} vs {}".format(len(indices), self.total_size))
    243
    244         # subsample

RuntimeError: 200 vs 104
The assert will fail too after fixing the RuntimeError but that is because of the padding.
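To make the numbers in the error concrete, here is a minimal sketch of the arithmetic. It assumes total_size is the smallest multiple of num_replicas that is at least num_samples, and that padding appends one full pass of the sampler at a time (as the `while` loop quoted later in this thread suggests). The helper name `padded_length` is hypothetical and not part of ignite.

```python
import math


def padded_length(num_samples, num_replicas):
    """Return (total_size, padded length, number of sampler passes).

    Assumption: total_size is num_samples rounded up to a multiple of
    num_replicas, and padding adds a whole sampler pass each time.
    """
    total_size = math.ceil(num_samples / num_replicas) * num_replicas
    length = 0
    passes = 0
    while length < total_size:
        length += num_samples  # one full pass of the wrapped sampler
        passes += 1
    return total_size, length, passes


print(padded_length(100, 8))  # (104, 200, 2)
```

With 100 samples and 8 replicas, one pass (100) falls short of 104, so a second full pass is appended and the length jumps to 200, which is the "200 vs 104" in the message.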
@ryanwongsa thanks for the repro snippet! Well, there are two points here:
1) len(indices) can indeed be larger than self.total_size, and this can be fixed by adding, for example:
if len(indices) > self.total_size:
    indices = indices[:self.total_size]
if len(indices) != self.total_size:
    raise RuntimeError("{} vs {}".format(len(indices), self.total_size))
2) A test on 100 samples seems small and depends on the seed. With num_samples=200, the test passes:
import torch
from torch.utils.data import WeightedRandomSampler
weights = torch.ones(100)
weights[:50] += 1
num_samples = 200
sampler = WeightedRandomSampler(weights, num_samples)
num_replicas = 8
dist_samplers = [FixedDistributedProxySampler(sampler, num_replicas=num_replicas, rank=i) for i in range(num_replicas)]
for seed in range(100):
    torch.manual_seed(seed)
    true_indices = list(sampler)
    indices_per_rank = []
    for s in dist_samplers:
        s.set_epoch(seed)
        indices_per_rank += list(s)
    set_indices_per_rank = set(indices_per_rank)
    set_true_indices = set(true_indices)
    assert set_indices_per_rank == set_true_indices, "{} | {}".format(set_true_indices - set_indices_per_rank, set_indices_per_rank - set_true_indices)
(Here FixedDistributedProxySampler is DistributedProxySampler with the two truncation lines from point 1 added.)
What do you think?
PS: It is true that for num_samples=100, the condition
assert set(indices_per_rank) == set(true_indices)
may not be satisfied, as indices_per_rank contains more generated samples than true_indices. A fairer comparison would use true_indices of length 104 instead of 100, drawn by sampling twice as is done inside DistributedProxySampler.
@vfdev-5 Thanks for the response
len(indices) != self.total_size seems a bit redundant, as the preceding lines in the sampler currently are:
while len(indices) < self.total_size:
    indices += list(self.sampler)
So len(indices) == self.total_size is guaranteed to hold if you add:
if len(indices) > self.total_size:
    indices = indices[:self.total_size]
I could be wrong, but it doesn't seem useful to have the RuntimeError check in place once the indices = indices[:self.total_size] truncation is added.
import torch
from torch.utils.data import WeightedRandomSampler
weights = torch.ones(100)
weights[:50] += 1
num_samples = 100
sampler = WeightedRandomSampler(weights, num_samples)
num_replicas = 8
dist_samplers = [FixedDistributedProxySampler(sampler, num_replicas=num_replicas, rank=i) for i in range(num_replicas)]
torch.manual_seed(0)
true_indices = list(sampler)[0:num_samples//num_replicas*num_replicas]
indices_per_rank = []
for s in dist_samplers:
    s.set_epoch(0)
    indices_per_rank += list(s)[0:num_samples//num_replicas]
assert set(indices_per_rank) == set(true_indices)
I am not sure if this is the correct way to test it, but it solves the test issue when num_samples is not divisible by num_replicas.
> I could be wrong, but it doesn't seem useful to have the RuntimeError check in place once the indices = indices[:self.total_size] truncation is added.
@ryanwongsa yes, you are right, this check is redundant and we can remove it if we truncate the indices as proposed.
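The pad-then-truncate logic agreed on here can be sketched without torch or ignite. This is only an illustration under stated assumptions, not the library's implementation: `one_pass` and `rank_indices` are hypothetical names, Python's `random.Random.choices` stands in for WeightedRandomSampler, seeding with the epoch stands in for `set_epoch`, and the strided slice per rank is assumed to mirror the "# subsample" step visible in the traceback.

```python
import math
import random


def one_pass(rng, weights, num_samples):
    # Stand-in for one full pass of a weighted sampler (with replacement).
    return rng.choices(range(len(weights)), weights=weights, k=num_samples)


def rank_indices(weights, num_samples, num_replicas, rank, epoch):
    # Same seed on every rank, so all ranks draw the same index stream.
    rng = random.Random(epoch)
    total_size = math.ceil(num_samples / num_replicas) * num_replicas

    indices = []
    while len(indices) < total_size:
        indices += one_pass(rng, weights, num_samples)  # pad with full passes

    # Truncate: after this line len(indices) == total_size always holds,
    # so a separate length check would never fire.
    indices = indices[:total_size]

    # Assumed subsampling scheme: every num_replicas-th index per rank.
    return indices[rank:total_size:num_replicas]


weights = [2.0] * 50 + [1.0] * 50
per_rank = [rank_indices(weights, 100, 8, r, epoch=0) for r in range(8)]
print([len(p) for p in per_rank])  # [13, 13, 13, 13, 13, 13, 13, 13]
```

Each of the 8 ranks gets 13 indices, 104 in total, regardless of how far the padding loop overshoots.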
> I am not sure if this is the correct way to test it, but it solves the test issue when num_samples is not divisible by num_replicas.
Yes, this could work like that. However, what matters is which samples the model sees during training. Here, a model on a single device sees 100 samples in one epoch, while the distributed model sees 104.
All samples (100 and 104) are generated with the desired sampler and should normally follow the intended distribution. But with DDP, the model sees 4 more samples. That is probably OK. In any case, the same thing can happen for large datasets, where the DDP model will see at most (world_size - 1) more samples per epoch, and IMO that is not a problem...
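As a quick check of the (world_size - 1) bound, the sketch below computes how many extra samples the padding adds. It assumes total_size is num_samples rounded up to the next multiple of world_size; `extra_samples` is a hypothetical helper, not an ignite function.

```python
def extra_samples(num_samples, world_size):
    # Round num_samples up to the next multiple of world_size (integer ceil).
    total_size = (num_samples + world_size - 1) // world_size * world_size
    return total_size - num_samples


print(extra_samples(100, 8))  # 4
print(max(extra_samples(n, 8) for n in range(1, 1001)))  # 7
```

For 8 processes the overhead never exceeds 7 samples per epoch, whatever the dataset size.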
If you would like to help us land a fix for this issue a bit faster, feel free to send a PR :) Otherwise, we'll fix it in the coming days...
@vfdev-5 I created PR #1192 with the changes you described above. The test has also been updated to reflect the newer test you described earlier. I am not sure if the PR is how you wanted it, so it would be good to get feedback. Thanks