Vision: [Feature request] Random dataset split

Created on 13 Dec 2017 · 4 comments · Source: pytorch/vision

It'd be great if we could have a function of the form

def random_split(dataset, lengths):

that takes a dataset and a list 'lengths' of ints that sum to len(dataset). The function returns a list of len(lengths) datasets, each a random, non-overlapping subset of the original with the corresponding length.

There are many ways to do this, but all I thought of are hacky as hell (e.g. creating two datasets of indices that then map to the original dataset in a nonstandard way).

Any ideas?

Most helpful comment

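Something like this should work, and doesn't seem very hacky:

from itertools import accumulate

import torch


class Subset(torch.utils.data.Dataset):
    """View of `dataset` restricted to the given indices."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        # Map the subset-local index to an index in the original dataset.
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)


def random_split(dataset, lengths):
    # The requested lengths must cover the dataset exactly.
    assert sum(lengths) == len(dataset)
    # Shuffle all indices once (as plain ints), then slice the permutation
    # into consecutive chunks, one per requested length.
    indices = torch.randperm(sum(lengths)).tolist()
    return [Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(accumulate(lengths), lengths)]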
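
To see the slicing logic in isolation, here is a minimal torch-free sketch of the same idea: shuffle the indices once, then cut the permutation at the running totals of the lengths. The helper name `split_indices` and the `seed` parameter are hypothetical, introduced only for this illustration, and the standard-library `random` module stands in for `torch.randperm`.

```python
import random
from itertools import accumulate


def split_indices(n, lengths, seed=None):
    """Return one list of shuffled indices per requested length (hypothetical helper)."""
    if sum(lengths) != n:
        raise ValueError("lengths must sum to the dataset size")
    perm = list(range(n))
    # A dedicated Random instance keeps the shuffle reproducible for a given seed.
    random.Random(seed).shuffle(perm)
    # accumulate(lengths) yields the end offset of each chunk;
    # offset - length is therefore the start of that chunk.
    return [perm[offset - length:offset]
            for offset, length in zip(accumulate(lengths), lengths)]


parts = split_indices(10, [7, 2, 1], seed=0)
sizes = [len(p) for p in parts]
```

Each index ends up in exactly one chunk, so the resulting subsets are disjoint and together cover the whole dataset.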

All 4 comments

Thanks!! You're the best :D

I believe such functionality could live in PyTorch itself or in a third-party library? I think this is general enough and not only necessary for vision tasks. What do you think?

I agree that this would be better placed somewhere under torch.utils.data, as it is useful for other types of datasets too. I've opened a PR in pytorch, so I will close this issue in vision.
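
For readers arriving later: PyTorch does ship a `random_split` function in `torch.utils.data`. The sketch below shows one way it might be used, assuming a reasonably recent PyTorch release that accepts the optional `generator` argument; the toy `TensorDataset` and the seed value are illustrative choices, not taken from this thread.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy dataset of 10 samples (illustrative only).
dataset = TensorDataset(torch.arange(10))

# Passing a seeded generator makes the split reproducible.
generator = torch.Generator().manual_seed(0)
train, val = random_split(dataset, [8, 2], generator=generator)

# Each returned object is a Subset that indexes into the original dataset.
covered = sorted(int(i) for subset in (train, val) for i in subset.indices)
```

The two subsets share the underlying data, so no samples are copied.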
