import torch
import torchvision
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

# MNIST images are single-channel, so Normalize takes one mean/std value
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=[0.5], std=[0.5])])

data_train = datasets.MNIST(root='./data/',
                            transform=transform,
                            train=True,
                            download=True)
data_test = datasets.MNIST(root='./data/',
                           transform=transform,
                           train=False)

data_loader_train = torch.utils.data.DataLoader(dataset=data_train,
                                                batch_size=64,
                                                shuffle=True)
data_loader_test = torch.utils.data.DataLoader(dataset=data_test,
                                               batch_size=64,
                                               shuffle=True)

images, labels = next(iter(data_loader_train))
print(images.shape)  # torch.Size([64, 1, 28, 28])

img = torchvision.utils.make_grid(images)
img = img.numpy().transpose(1, 2, 0)  # CHW -> HWC for matplotlib
img = img * 0.5 + 0.5                 # undo the normalization
print([labels[i] for i in range(64)])
plt.imshow(img)
plt.show()
In the code above, images is a [64, 1, 28, 28] tensor. I expected make_grid to return a [1, 242, 242] tensor (242 = 8 × 28 + 9 × 2, from the default nrow=8 and padding=2), but it returns a [3, 242, 242] tensor instead.
This is done on purpose in https://github.com/pytorch/vision/blob/ccbb3221b7f0637f1706df29d2c2995e9d5171bf/torchvision/utils.py#L44-L45: single-channel inputs are replicated across three channels so that all the images returned by make_grid follow the same format.
You can always get the original image back by slicing over the first dimension though, so I don't see any harm in it.
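For example, a minimal sketch of that slicing (the random tensor here is just a stand-in for the MNIST batch above):

import torch
import torchvision

images = torch.randn(64, 1, 28, 28)        # stand-in for a single-channel batch
grid = torchvision.utils.make_grid(images)
print(grid.shape)                          # torch.Size([3, 242, 242])

# the three channels are identical copies, so slicing is lossless
assert torch.equal(grid[0], grid[1]) and torch.equal(grid[0], grid[2])
gray = grid[:1]                            # back to [1, 242, 242]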
I get the idea, but in my opinion keeping the output's channel count consistent with the input would be better. Thank you anyway.