First off, JAX is a great library, love using it.
My issue is that my PyTorch DataLoader freezes whenever I use more than 0 workers (num_workers > 0). The dataset itself returns jnp arrays. How can I fix this? Here's a minimal example of what I'm doing:
import os

import jax.numpy as jnp
import numpy as np
import torch
import torchvision


class SpecialMNIST(torch.utils.data.Dataset):
    def __init__(self, train, seed=0):
        super().__init__()
        self._data = torchvision.datasets.MNIST(os.getcwd(), train, download=True)
        self._data_len = len(self._data)

    def __getitem__(self, index):
        # MNIST yields (PIL image, int label); convert the image to a jnp array.
        img, label = self._data[index]
        return jnp.asarray(np.asarray(img)), label

    def __len__(self):
        return self._data_len


def collate_fn(batch):
    # Stack jnp arrays, recurse into (image, label) tuples, wrap anything else.
    if isinstance(batch[0], jnp.ndarray):
        return jnp.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        return type(batch[0])(collate_fn(samples) for samples in zip(*batch))
    else:
        return jnp.asarray(batch)
Then the following code works:
dataset = SpecialMNIST(train=False)
dataloader = torch.utils.data.DataLoader(dataset,
                                         collate_fn=collate_fn,
                                         num_workers=0)
next(iter(dataloader))
But the following code just hangs:
dataset = SpecialMNIST(train=False)
dataloader = torch.utils.data.DataLoader(dataset,
                                         collate_fn=collate_fn,
                                         num_workers=1)
next(iter(dataloader))
The problem goes away when I change the jnp arrays to np arrays in the dataloader and collate function. However, in my actual use case, I have some complex data augmentation that I would like to JIT compile and run on the CPU. Is there any way to do this in JAX? Or do I have to stick to normal numpy functions when data loading? Thanks!
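A minimal sketch of that kind of jitted augmentation pinned to the CPU, assuming a hypothetical `augment` function (random horizontal flip plus scaling to [0, 1]) as a stand-in for the real pipeline. Instead of `jit`'s `backend` argument, it commits the inputs to `jax.devices("cpu")[0]` with `jax.device_put`; a jitted function runs on the device its committed inputs live on. `augment_on_cpu` is likewise a hypothetical helper name.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical augmentation: random horizontal flip, then scale pixels to [0, 1].
@jax.jit
def augment(key, img):
    flip = jax.random.bernoulli(key)  # scalar bool, True with probability 0.5
    img = jnp.where(flip, img[:, ::-1], img)
    return img.astype(jnp.float32) / 255.0


cpu = jax.devices("cpu")[0]


def augment_on_cpu(seed, img_uint8):
    # device_put commits both inputs to the CPU device, so the jitted call runs there.
    key = jax.device_put(jax.random.PRNGKey(seed), cpu)
    img = jax.device_put(img_uint8, cpu)
    # np.asarray hands back a plain NumPy array for the data-loading code.
    return np.asarray(augment(key, img))
```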
I think you can use np.asarray() to convert back to a regular NumPy array after the JAX call; if I'm not mistaken, this avoids copying the data.
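Concretely, that suggestion amounts to something like the sketch below: do the JAX work per sample, convert the result with `np.asarray`, and collate with NumPy only. `to_numpy_sample` and `numpy_collate` are hypothetical names, and whether `np.asarray` actually skips a copy depends on the array living in host memory. As the next reply reports, this change alone did not stop the hang.

```python
import jax.numpy as jnp
import numpy as np


def to_numpy_sample(img, label):
    # Do the (possibly jitted) JAX work on the sample...
    x = jnp.asarray(img, dtype=jnp.float32) / 255.0
    # ...then convert back so the DataLoader and collate function only see NumPy.
    return np.asarray(x), label


def numpy_collate(batch):
    # NumPy counterpart of collate_fn: stack arrays, recurse into tuples/lists.
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        return type(batch[0])(numpy_collate(samples) for samples in zip(*batch))
    else:
        return np.asarray(batch)
```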
Thanks for the tip @cgarciae! Even with this change (i.e. returning an np.array and collating np.arrays), as long as I have any JAX calls in my dataset/dataloader, it hangs when using multiple workers.
I've seen this behavior on macOS when you use threads within processes; not sure if it's the same issue, though.
@n2cholas would adding
import torch.multiprocessing as multiprocessing
multiprocessing.set_start_method('spawn')
at the top of your script help? What you're seeing might be due to the XLA state being inconsistent after forking off the data loading processes. It will make their start-up quite a bit more expensive, but it might at least be more correct. Also, please remember to wrap your code in if __name__ == '__main__' if you do that (see the multiprocessing docs).
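Putting that suggestion together, a minimal script might look like the sketch below. Instead of the global `multiprocessing.set_start_method('spawn')`, it passes `multiprocessing_context="spawn"` to the `DataLoader`, which applies the spawn start method to that loader's workers only; either form should avoid forking a process that already holds XLA state. `ToyDataset`, `scale`, `numpy_collate`, and `make_loader` are hypothetical stand-ins so the sketch runs without downloading MNIST.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch
import torch.utils.data


@jax.jit
def scale(x):
    # Stand-in for a jitted augmentation; each spawned worker compiles it on first use.
    return x.astype(jnp.float32) / 255.0


class ToyDataset(torch.utils.data.Dataset):
    # Hypothetical replacement for SpecialMNIST: random 28x28 uint8 images.
    def __init__(self, n=64, seed=0):
        rng = np.random.default_rng(seed)
        self._imgs = rng.integers(0, 256, size=(n, 28, 28), dtype=np.uint8)

    def __getitem__(self, index):
        # Run JAX inside the worker, return NumPy so collation stays NumPy-only.
        return np.asarray(scale(self._imgs[index])), index

    def __len__(self):
        return len(self._imgs)


def numpy_collate(batch):
    imgs, labels = zip(*batch)
    return np.stack(imgs), np.asarray(labels)


def make_loader(num_workers=2, batch_size=8):
    return torch.utils.data.DataLoader(
        ToyDataset(),
        batch_size=batch_size,
        collate_fn=numpy_collate,
        num_workers=num_workers,
        # Fresh interpreter per worker instead of fork(), so no XLA state is
        # inherited from the parent. Only valid when num_workers > 0.
        multiprocessing_context="spawn" if num_workers > 0 else None,
    )


if __name__ == "__main__":  # required with spawn: workers re-import this module
    imgs, labels = next(iter(make_loader()))
    print(imgs.shape, labels.shape)
```

The trade-off is the one mentioned above: every worker starts a new interpreter and re-imports JAX and PyTorch, so start-up is slower than with fork.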
@apaszke that worked, thanks!
Just a couple of follow up questions regarding using JAX with PyTorch dataloaders:
pin_memory=True in the PyTorch DataLoaders) to help speed up the host->accelerator data transfer?
jax.jit(data_aug_fn, backend='cpu'); however, the docs say that this is "an experimental feature and the API is likely to change". Is there a more idiomatic way of ensuring computation happens on the CPU instead of on the accelerator?

My follow-up questions were not relevant to this issue, so I will close this issue and ask them in the new Discussions tab. Thanks again @apaszke for your help!!