
PyTorch: Image not displaying properly

I have the following code portion:

dataset = trainDataset()
train_loader = DataLoader(dataset,batch_size=1,shuffle=True)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

images = []
image_labels = []

for i, data in enumerate(train_loader,0):
    inputs, labels = data
    inputs, labels = inputs.to(device), labels.to(device)
    inputs, labels = inputs.float(), labels.float()
    images.append(inputs)
    image_labels.append(labels)

image = images[7]
image = image.numpy()
image = image.reshape(416,416,3)
img = Image.fromarray(image,'RGB')
img.show()

The issue is that the image doesn't display properly. For instance, the dataset I have contains images of cats and dogs, but the image displayed looks as shown below. Why is that?

[image]

EDIT 1

So, after @flawr's nice explanation, I have the following:

image = images[7]
image = image[0,...].permute([1,2,0])
image = image.numpy()
img = Image.fromarray(image,'RGB')
img.show()

And the image looks as shown below. I'm not sure whether it is a NumPy thing or the way the image is represented and displayed. I would also like to note that I get a different display of the image at every run, but it is always something close to the image displayed below.

[image]

EDIT 2

I think the issue now is with how to represent the image. By referring to this solution, I now get the following:

image = images[7]
image = image[0,...].permute([1,2,0])
image = image.numpy()
image = (image * 255).astype(np.uint8)
img = Image.fromarray(image,'RGB')
img.show()

Which produces the following image as expected :-)

[image]
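
For reference, the steps from EDIT 2 can be collected into one small helper. This is only a sketch under some assumptions that are not in the code above: the function name batch_tensor_to_pil is made up, the batch is assumed to have shape (1, 3, H, W) with float values in [0, 1] (which is what the multiplication by 255 suggests), and a .cpu() call is added because .numpy() does not work on a CUDA tensor. A random tensor stands in for a real batch from the DataLoader.

```python
import numpy as np
import torch
from PIL import Image


def batch_tensor_to_pil(batch):
    """Convert a (1, 3, H, W) float tensor with values in [0, 1] to a PIL image.

    Hypothetical helper; the value range is an assumption about the dataset.
    """
    # Drop the batch dimension and make sure the data lives on the CPU.
    chw = batch[0].detach().cpu()
    # Move the channel axis to the end: (3, H, W) -> (H, W, 3).
    hwc = chw.permute(1, 2, 0).numpy()
    # PIL expects 8-bit integers for RGB data, so rescale and cast.
    hwc = (np.clip(hwc, 0.0, 1.0) * 255).round().astype(np.uint8)
    # A (H, W, 3) uint8 array is interpreted as an RGB image.
    return Image.fromarray(np.ascontiguousarray(hwc))


# Stand-in for one batch of size 1 (not real data).
fake_batch = torch.rand(1, 3, 416, 416)
img = batch_tensor_to_pil(fake_batch)
print(img.size, img.mode)
```

If the dataset normalizes its images (for example with a mean and standard deviation), the values would have to be mapped back to [0, 1] before this conversion.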

In PyTorch you usually represent pictures with tensors of shape

(channels, height, width)

You then seem to reshape it to what you expect would be

(height, width, channels)

Note that these tensors or arrays are actually stored as a 1D "array", and the multiple dimensions just come from defining strides (check out "How to understand numpy strides for layman?").
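
To make the stride idea a bit more concrete, here is a small NumPy sketch (the array contents and the variable names are just for illustration). It shows that reshape leaves the underlying buffer alone and only changes the shape and strides used to index into it.

```python
import numpy as np

# A (channels, height, width) = (3, 2, 2) array holding 12 consecutive numbers.
a = np.arange(12, dtype=np.int64).reshape(3, 2, 2)

# Strides are given in bytes: how far to jump in the flat buffer
# to advance one step along each axis (8 bytes per int64 element).
print(a.strides)

# Reshaping a contiguous array gives a view on the same buffer
# with a different shape and different strides.
b = a.reshape(2, 2, 3)
print(b.strides)

same_buffer = np.shares_memory(a, b)
same_order = bool((a.ravel() == b.ravel()).all())
print(same_buffer, same_order)
```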

In your particular case this means that consecutive values (which were values of the same color channel and the same row) are now interpreted as different color channels.

So let's say you have a 2x2 image with 3 color channels, and let's say it is a chessboard pattern. In PyTorch that would look something like the following array of shape (3, 2, 2):

[[[1,0],[0,1]],[[1,0],[0,1]],[[1,0],[0,1]]]

The underlying internal array is just

[  1,0 , 0,1  ,  1,0 , 0,1  ,  1,0 , 0,1  ]

So reshaping to (2, 2, 3) would look like so:

[[[1,0,0],[1,1,0]],[[0,1,1],[0,0,1]]]

which immediately shows how the image will be completely jumbled. Reshaping really just means setting the brackets in different places!
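
The chessboard example can be reproduced in a few lines of NumPy (variable names are only illustrative):

```python
import numpy as np

# Chessboard with 3 identical channels, stored as (channels, height, width).
chw = np.array([[[1, 0], [0, 1]],
                [[1, 0], [0, 1]],
                [[1, 0], [0, 1]]])

# reshape keeps the flat order of the values and only regroups them.
reshaped = chw.reshape(2, 2, 3)
print(reshaped.tolist())

# In the real image the top-left pixel is 1 in every channel,
# but after the reshape its three "channels" are a mix of neighbouring values.
top_left_pixel = reshaped[0, 0].tolist()
print(top_left_pixel)
```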

So what you probably want instead of reshape is permute([1, 2, 0]) (called transpose in NumPy), which actually moves the channel axis to the end so that each pixel keeps its own three channel values.
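
As a quick check on the same chessboard, here is a sketch comparing permute with NumPy's transpose (variable names are again just for illustration):

```python
import torch

# The same chessboard as a (channels, height, width) = (3, 2, 2) tensor.
chw = torch.tensor([[[1, 0], [0, 1]]] * 3)

# permute reorders the axes: (C, H, W) -> (H, W, C).
hwc = chw.permute(1, 2, 0)
print(hwc.tolist())

# Every pixel now carries its own three channel values.
top_left_pixel = hwc[0, 0].tolist()
top_right_pixel = hwc[0, 1].tolist()

# The NumPy counterpart is transpose with the same axis order.
np_hwc = chw.numpy().transpose(1, 2, 0)
matches_numpy = bool((hwc.numpy() == np_hwc).all())
print(top_left_pixel, top_right_pixel, matches_numpy)
```

Note that permute returns a view with reordered strides rather than copying the data; calling .contiguous() afterwards gives a tensor whose memory is laid out in the new order, if some later step needs that.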
