JAX: Multiple Workers with PyTorch DataLoader

Created on 9 Jun 2020 · 6 comments · Source: google/jax

First off, JAX is a great library, love using it.

My issue is that my PyTorch DataLoader freezes whenever I use >0 workers. Indexing the dataset directly works fine with jnp arrays. How can I fix this? Here's a minimal example of what I'm doing:

import os
import jax.numpy as jnp
import numpy as np
import torch
import torchvision


class SpecialMNIST(torch.utils.data.Dataset):
    def __init__(self, train, seed=0):
        super().__init__()
        self._data = torchvision.datasets.MNIST(os.getcwd(), train, download=True)
        self._data_len = len(self._data)

    def __getitem__(self, index):
        img, label = self._data[index]
        # PIL image -> NumPy array -> JAX array
        return jnp.asarray(np.asarray(img)), label

    def __len__(self):
        return self._data_len


def collate_fn(batch):
    # Recursively collate: stack JAX arrays, recurse into tuples/lists,
    # and convert everything else (e.g. integer labels) to a jnp array.
    if isinstance(batch[0], jnp.ndarray):
        return jnp.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        return type(batch[0])(collate_fn(samples) for samples in zip(*batch))
    else:
        return jnp.asarray(batch)

The following code works:

dataset = SpecialMNIST(train=False)
dataloader = torch.utils.data.DataLoader(dataset, 
                                         collate_fn=collate_fn,
                                         num_workers=0)
next(iter(dataloader))

But the following code just hangs:

dataset = SpecialMNIST(train=False)
dataloader = torch.utils.data.DataLoader(dataset, 
                                         collate_fn=collate_fn,
                                         num_workers=1)
next(iter(dataloader))

The problem goes away when I change the jnp arrays to np arrays in the dataset and collate function. However, in my actual use case, I have some complex data augmentation that I would like to JIT-compile and run on the CPU. Is there any way to do this in JAX, or do I have to stick to plain NumPy functions when loading data? Thanks!

All 6 comments

I think you can use np.asarray() to convert back to regular NumPy after the call to JAX; if I am not mistaken, this avoids creating a copy of the data.
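
For example, a minimal sketch of that idea applied to the __getitem__ above (some_jax_aug is a hypothetical stand-in for whatever JAX computation runs per sample):

    def __getitem__(self, index):
        img, label = self._data[index]
        out = some_jax_aug(jnp.asarray(np.asarray(img)))  # hypothetical JAX transform
        # hand plain NumPy back to the DataLoader; np.asarray on a
        # device array should avoid a copy, if I am not mistaken
        return np.asarray(out), label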

Thanks for the tip @cgarciae! Even with this change (i.e. returning an np.array and collating np.arrays), as long as I have any JAX calls in my dataset/dataloader, it hangs when using multiple workers.

I've seen this behavior on macOS, where threads are used within processes; not sure if it's the same issue, though.

@n2cholas would adding

import torch.multiprocessing as multiprocessing
multiprocessing.set_start_method('spawn')

at the top of your script help? What you're seeing might be due to the XLA state being inconsistent after forking off the data loading processes. It will make their start-up quite a bit more expensive, but it might at least be more correct. Also, please remember to wrap your code in if __name__ == '__main__' if you do that (see the multiprocessing docs).
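
Applied to the example above, that would look something like this (a sketch; per the multiprocessing docs, the start method is set under the main guard, before any workers are created):

import torch
import torch.multiprocessing as multiprocessing

if __name__ == '__main__':
    multiprocessing.set_start_method('spawn')

    dataset = SpecialMNIST(train=False)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             collate_fn=collate_fn,
                                             num_workers=1)
    # with spawn, workers re-import the module instead of forking,
    # so they start with fresh XLA state
    batch = next(iter(dataloader))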

@apaszke that worked, thanks!

Just a couple of follow up questions regarding using JAX with PyTorch dataloaders:

  1. Is there any way to move the arrays to pinned memory (like pin_memory=True in the PyTorch DataLoaders) to help speed up the host->accelerator data transfer?
  2. To ensure my data augmentation happens on the CPU, I'm using jax.jit(data_aug_fn, backend='cpu'); however, the docs say this is "an experimental feature and the API is likely to change". Is there a more idiomatic way of ensuring computation happens on the CPU instead of on the accelerator? (Sketch of my current usage below.)
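
For context, this is roughly how I'm using it today (data_aug_fn here is a placeholder for my real augmentation):

import jax
import jax.numpy as jnp

def data_aug_fn(img):
    # placeholder augmentation: cast and rescale to [0, 1]
    return img.astype(jnp.float32) / 255.0

# backend='cpu' pins compilation and execution to the CPU backend,
# even when an accelerator is present (experimental per the docs)
cpu_data_aug = jax.jit(data_aug_fn, backend='cpu')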

My follow-up questions aren't specific to this issue, so I'll close it and ask them in the new Discussions tab. Thanks again @apaszke for your help!!
