Vision: PyTorch pre-trained models preprocessing results in 9 images

Created on 16 Jun 2020 · 4 comments · Source: pytorch/vision

I am using vgg16, and for preprocessing I use the transforms module (as shown in the documentation). I don't know why, but when it takes my image as input, it outputs nine small copies of the input image combined into one single image (nonetheless, the classification output is correct).

Is this a problem?

question


All 4 comments

Hi @aliamiri1380, please post a minimal code example of what you are doing. Otherwise we can only give you general answers. In this case: I don't know whether it is a problem within your workflow, but it is definitely unexpected.

Hi, this is my code

import torch as T
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

device = T.device('cuda' if T.cuda.is_available() else 'cpu')

_ = T.hub.load('pytorch/vision:v0.6.0', 'resnext101_32x8d', pretrained=True)
_.eval()
def preprocess(model, input_image):
    img_size = (256,256,3)
    img = Image.open(input_image)
    img = img.resize(img_size[:-1]).convert('RGB')
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.ColorJitter((1,2),(1,2),(1,2),(0, .5)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    img = preprocess(img).unsqueeze(0).to(device)
    plt.imshow(img.cpu().detach().numpy().reshape(img_size))
    plt.show()
    model.to(device)
    with T.no_grad():
        output = model(img)
    return output

preprocess(_, 'bmw.jpeg').argmax(-1).item()

Here are my input image and the output image after preprocessing, respectively:

[input image]

[preprocessed image]

After preprocessing, the image becomes nine copies of the input.

The problem lies in this line

plt.imshow(img.cpu().detach().numpy().reshape(img_size))

and specifically in .reshape(img_size). torch works with a channels-first layout, whereas pyplot expects channels-last. The .reshape does not fail because the total number of elements matches, but it mixes the channel dimension with the spatial dimensions, which is why you see the motif repeated multiple times.

To overcome this, you can replace .reshape(img_size) with .squeeze(0).transpose(1, 2, 0): the first call removes the extra batch dimension, and the second moves the channels to the last dimension.
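As a minimal sketch (assuming img is the (1, 3, 256, 256) tensor produced by the snippet above), the display step would become:

# img has shape (1, 3, 256, 256): batch, channels, height, width
arr = img.cpu().detach().numpy()
arr = arr.squeeze(0).transpose(1, 2, 0)  # -> (256, 256, 3), channels-last for pyplot
plt.imshow(arr)  # note: the Normalize'd values fall outside [0, 1], so pyplot clips them
plt.show()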

If you are doing this conversion solely to show the image, transforms.ToPILImage might be helpful.
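For example, a rough sketch, assuming the conversion happens before Normalize (ToPILImage expects float values in [0, 1] or a uint8 tensor); tensor_img is a hypothetical name for that intermediate tensor:

from torchvision import transforms

to_pil = transforms.ToPILImage()
# tensor_img: a (3, H, W) float tensor straight out of transforms.ToTensor()
pil_img = to_pil(tensor_img)
pil_img.show()  # or plt.imshow(pil_img)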

Thank you for the answer, I hadn't realized this key point.

