It'd be great if we could have a function of the form
def random_split(dataset, lengths):
that takes a dataset and a list 'lengths' of ints that sum to len(dataset); the function returns a list of len(lengths) datasets, giving the corresponding random split of the dataset.
There are many ways to do this, but all the ones I thought of are hacky as hell (e.g. creating two datasets of indices that then map to the original dataset in a nonstandard way).
Any ideas?
Something like this should work, and doesn't seem very hacky:
from itertools import accumulate
import torch

class Subset(torch.utils.data.Dataset):
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        # Map the subset index to an index into the wrapped dataset.
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

def random_split(dataset, lengths):
    assert sum(lengths) == len(dataset)
    # One shared permutation; each subset takes a contiguous slice of it.
    indices = torch.randperm(sum(lengths))
    return [Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(accumulate(lengths), lengths)]
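The slicing logic (a single permutation, with accumulate giving each split's end offset) can be checked with plain Python, using a list to stand in for the dataset and random.shuffle for torch.randperm — a minimal sketch, not the torch version:

```python
import random
from itertools import accumulate

def random_split(dataset, lengths):
    assert sum(lengths) == len(dataset)
    # Shuffle one index permutation; each split gets a contiguous slice of it.
    indices = list(range(len(dataset)))
    random.shuffle(indices)
    return [[dataset[i] for i in indices[offset - length:offset]]
            for offset, length in zip(accumulate(lengths), lengths)]

data = list(range(10))
splits = random_split(data, [7, 3])
print([len(s) for s in splits])  # [7, 3]
```

Since accumulate([7, 3]) yields the offsets 7 and 10, the slices are indices[0:7] and indices[7:10], which partition the permutation exactly.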
Thanks!! You're the best :D
I believe such functionality could live in pytorch or in a third-party library? I think this is general enough and not only needed for vision tasks. Wdyt?
I agree that this would be better placed somewhere under torch.utils.data, as it is useful for other types of datasets too. I've opened a PR in pytorch, so I'll close this issue in vision.
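For reference, the helper is now available as torch.utils.data.random_split; a minimal usage check (assuming a reasonably recent PyTorch install, and using TensorDataset just as a toy dataset):

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy dataset of 10 single-tensor samples.
ds = TensorDataset(torch.arange(10))
train, val = random_split(ds, [7, 3])
print(len(train), len(val))  # 7 3
```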