My dataloader returns a list of lists for the labels (multi-label classification) and a tensor of images for each batch. In DataParallel mode, the labels are not being split into per-GPU "sub-batches"; instead, every GPU receives the full list of labels. Is there a way to get this splitting for non-tensors as well?
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    ...
    def collate(self, batch):
        images, labels = list(zip(*batch))
        # images are stacked into a single tensor; labels stay a plain Python list
        return torch.stack(images), [label_set for label_set in labels]
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    ...
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        # y is a Python list, y_hat is a tensor; compare their batch sizes
        print(f'Val sizes: len(y)={len(y)}, y_hat.shape={y_hat.shape}')
This prints that y has the length of the full batch, while y_hat has the same length as the x tensor (which is smaller here). So the x tensor gets split correctly, but y does not.
Is this an issue with Lightning, or with the DataParallel module?
Hi! Thanks for your contribution, great first issue!
It is the expected behavior of the DataParallel module: anything in the batch that is not a Tensor gets duplicated to all GPUs instead of being split.
@BartekRoszak @zlenyk
I believe that is the expected behaviour. Lightning uses PyTorch's scatter to distribute the input across GPUs: https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/scatter_gather.py#L11
PyTorch's scatter recurses into nested lists and treats each Tensor at the bottom as a data batch to split; everything else is replicated. So your label lists end up duplicated on every GPU instead of being chunked.
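To make that concrete, here is a minimal sketch (my assumptions: a machine with at least two CUDA devices, and example shapes/labels I made up) that calls the scatter helper linked above directly on a (tensor, list-of-lists) pair. The image tensor gets chunked along dim 0, while every device receives the full label list:

import torch
from torch.nn.parallel.scatter_gather import scatter

if torch.cuda.device_count() >= 2:
    images = torch.randn(8, 3, 32, 32)                          # tensor leaf: chunked along dim 0
    labels = [[0, 2], [1], [3], [0, 1], [2], [4], [1, 2], [3]]  # non-tensor leaves: replicated
    scattered = scatter((images, labels), target_gpus=[0, 1])
    for imgs, lbls in scattered:
        # prints torch.Size([4, 3, 32, 32]) and 8 on each device
        print(imgs.shape, len(lbls))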
The best way to work around this would be to convert it to a tensor; torch.tensor([label_set for label_set in labels]) should work fine, as long as every label set has the same length.
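If the label sets can have different lengths (common in multi-label classification), torch.tensor on a ragged list will raise an error. One alternative sketch is to encode each label set as a fixed-size multi-hot row inside collate; num_classes below is a hypothetical dataset attribute I'm assuming exists, not something from the original code:

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    ...
    def collate(self, batch):
        images, labels = zip(*batch)
        # Turn each variable-length label set into a fixed-size multi-hot row,
        # so the targets become a single Tensor that scatter can split along dim 0.
        targets = torch.zeros(len(labels), self.num_classes)  # num_classes is an assumed attribute
        for i, label_set in enumerate(labels):
            targets[i, list(label_set)] = 1.0
        return torch.stack(images), targets

A [batch, num_classes] float target like this also plugs directly into BCEWithLogitsLoss, which is the usual loss for multi-label setups.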
I see, thank you guys for the answers.