How do I display a single image in PyTorch?

How do I display a PyTorch Tensor of shape (3, 224, 224) representing a 224x224 RGB image? Using plt.imshow(image) gives the error:

TypeError: Invalid dimensions for image data

Given a Tensor representing the image, use .permute() to move the channels to the last dimension:

plt.imshow(tensor_image.permute(1, 2, 0))

Note: permute() does not copy or allocate memory; it returns a view of the same data. from_numpy() likewise shares memory with the source array.
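The no-copy behaviour noted above can be checked directly. This is a minimal sketch using a tiny made-up tensor (the names chw, hwc, arr and shared are illustrative only, not from the answer):

```python
import numpy as np
import torch

# A tiny stand-in for an image: 3 channels, 2 rows, 4 columns (C, H, W).
chw = torch.arange(24, dtype=torch.uint8).reshape(3, 2, 4)

# permute returns a view: new shape and strides, same underlying storage.
hwc = chw.permute(1, 2, 0)
print(hwc.shape)                         # torch.Size([2, 4, 3])
print(hwc.data_ptr() == chw.data_ptr())  # True: no data was copied
print(hwc.is_contiguous())               # False: only the strides changed

# Writing through the view is visible in the original tensor.
hwc[0, 0, 0] = 99
print(chw[0, 0, 0].item())               # 99

# from_numpy also shares memory with the source array.
arr = np.zeros((2, 4, 3), dtype=np.uint8)
shared = torch.from_numpy(arr)
arr[0, 0, 0] = 7
print(shared[0, 0, 0].item())            # 7
```

Because the permuted tensor is non-contiguous, code that needs a contiguous layout afterwards has to call .contiguous(), which does allocate a copy.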

As the example below shows, matplotlib works fine even without converting the tensor to a NumPy array. But PyTorch image tensors are channels-first, so to use them with matplotlib you need to permute the axes to channels-last:

Code:

# In newer SciPy releases, face() lives in scipy.datasets instead of scipy.misc.
from scipy.misc import face
import matplotlib.pyplot as plt
import torch

np_image = face()
print(type(np_image), np_image.shape)
tensor_image = torch.from_numpy(np_image)
print(type(tensor_image), tensor_image.shape)
# Permute to channels-first (C, H, W). Note that .view() would only
# regroup the values in memory order, not move the channel axis.
tensor_image = tensor_image.permute(2, 0, 1)
print(type(tensor_image), tensor_image.shape)

# If you try to plot an image with shape (C, H, W),
# you will get a TypeError:
# plt.imshow(tensor_image)

# So we need to permute it back to (H, W, C):
tensor_image = tensor_image.permute(1, 2, 0)
print(type(tensor_image), tensor_image.shape)

plt.imshow(tensor_image)
plt.show()

Output:

<class 'numpy.ndarray'> (768, 1024, 3)
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
<class 'torch.Tensor'> torch.Size([3, 768, 1024])
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
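The shapes printed above do not show whether the pixels ended up in the right place: .view() and .permute() can produce the same shape with different contents. A small sketch with a hypothetical 2x2 image (values chosen so each pixel's channels are easy to recognise) illustrates the difference:

```python
import torch

# Hypothetical 2x2 "image" with 3 channels, stored channels-last (H, W, C).
# Pixel k holds (r, g, b) = (10*k, 10*k + 1, 10*k + 2).
hwc = torch.tensor([[[0, 1, 2], [10, 11, 12]],
                    [[20, 21, 22], [30, 31, 32]]])

permuted = hwc.permute(2, 0, 1)  # true channels-first layout
viewed = hwc.view(3, 2, 2)       # same numbers, regrouped in memory order

print(permuted[0].tolist())  # [[0, 10], [20, 30]]: the red value of each pixel
print(viewed[0].tolist())    # [[0, 1], [2, 10]]: a mix of r, g and b values

# Viewing back with the original shape undoes the regrouping, which is why
# a view/view round trip still displays correctly.
print(torch.equal(viewed.view(2, 2, 3), hwc))  # True
```

So .view() is only safe here when it is exactly undone before plotting; to actually move the channel axis, use .permute().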

Given the image is loaded as described and stored in the variable image:

plt.imshow(transforms.ToPILImage()(image), interpolation="bicubic")
#transforms.ToPILImage()(image).show() # Alternatively

Or, as Soumith suggested:

def show(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')

PyTorch modules processing image data expect tensors in the format C × H × W. [1]
Whereas Pillow and Matplotlib expect image arrays in the format H × W × C. [2]

You can easily convert tensors to and from this format with a TorchVision transform:

import torchvision.transforms.functional as F

F.to_pil_image(image_tensor)

Or by directly permuting the axes:

image_tensor.permute(1, 2, 0)

  1. PyTorch modules dealing with image data require tensors to be laid out as C × H × W: channels, height, and width, respectively.

  2. Note how we have to use permute to change the order of the axes from C × H × W to H × W × C to match what Matplotlib expects.

A complete example, given an image pathname img_path:

import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms

image = Image.open(img_path)
plt.imshow(transforms.ToPILImage()(transforms.ToTensor()(image)), interpolation="bicubic")

Note that transforms.ToTensor and transforms.ToPILImage are classes: the first pair of parentheses constructs the transform object and the second calls it on the image, hence the unusual bracketing.

Use show_image from fastai:

from fastai.vision.all import show_image

I've written a simple function to visualize a PyTorch tensor using matplotlib.

import numpy as np
import matplotlib.pyplot as plt
import torch

def show(*imgs):
    '''
     input imgs can be single or multiple tensor(s), this function uses matplotlib to visualize.
     Single input example:
     show(x) gives the visualization of x, where x should be a torch.Tensor
        if x is a 4D tensor (like an image batch of size b(atch)*c(hannel)*h(eight)*w(idth)), this function splits x along the batch dimension, showing b subplots in total, where each subplot displays at most the first 3 channels (3*h*w).
        if x is a 3D tensor, this function shows first 3 channels at most (in RGB format)
        if x is a 2D tensor, it will be shown as grayscale map
     
     Multiple input example:      
     show(x,y,z) produces three windows, displaying x, y, z respectively, where x,y,z can be in any form described above.
    '''
    img_idx = 0
    for img in imgs:
        img_idx +=1
        plt.figure(img_idx)
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu()

            if img.dim()==4: # 4D tensor
                bz = img.shape[0]
                c = img.shape[1]
                if bz==1 and c==1:  # single grayscale image
                    img=img.squeeze()
                elif bz==1 and c==3: # single RGB image
                    img=img.squeeze()
                    img=img.permute(1,2,0)
                elif bz==1 and c > 3: # multiple feature maps
                    img = img[:,0:3,:,:]
                    img = img.permute(0, 2, 3, 1)[:]
                    print('warning: more than 3 channels! only channels 0,1,2 are preserved!')
                elif bz > 1 and c == 1:  # multiple grayscale images
                    img=img.squeeze()
                elif bz > 1 and c == 3:  # multiple RGB images
                    img = img.permute(0, 2, 3, 1)
                elif bz > 1 and c > 3:  # multiple feature maps
                    img = img[:,0:3,:,:]
                    img = img.permute(0, 2, 3, 1)[:]
                    print('warning: more than 3 channels! only channels 0,1,2 are preserved!')
                else:
                    raise Exception("unsupported type!  " + str(img.size()))
            elif img.dim()==3: # 3D tensor
                bz = 1
                c = img.shape[0]
                if c == 1:  # grayscale
                    img=img.squeeze()
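                elif c == 3:  # RGB
                    img = img.permute(1, 2, 0)
                elif c > 3:  # feature maps: keep only the first 3 channels
                    img = img[0:3].permute(1, 2, 0)
                    print('warning: more than 3 channels! only channels 0,1,2 are preserved!')
                else:
                    raise Exception("unsupported type!  " + str(img.size()))
            elif img.dim()==2:
                bz = 1  # a single grayscale map; bz is read below when plotting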
            else:
                raise Exception("unsupported type!  "+str(img.size()))


            img = img.numpy()  # convert to numpy
            img = img.squeeze()
            if bz ==1:
                plt.imshow(img, cmap='gray')
                # plt.colorbar()
                # plt.show()
            else:
                for idx in range(0,bz):
                    plt.subplot(int(bz**0.5),int(np.ceil(bz/int(bz**0.5))),int(idx+1))
                    plt.imshow(img[idx], cmap='gray')

        else:
            raise Exception("unsupported type:  "+str(type(img)))

PyTorch image tensors have the shape (channel, height, width), whereas matplotlib needs (height, width, channel), so permute the axes:

plt.imshow(white_torch.permute(1, 2, 0))

Or display it directly, if you prefer:

import torch
import torchvision
from torchvision.io import read_image
import torchvision.transforms as T

!wget 'https://images.unsplash.com/photo-1553284965-83fd3e82fa5a?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxleHBsb3JlLWZlZWR8NHx8fGVufDB8fHx8&w=1000&q=80'  -O white_horse.jpg

white_torch = read_image('white_horse.jpg')  # uint8 tensor of shape (C, H, W)

# In a notebook, the returned PIL image is rendered inline as the cell output.
T.ToPILImage()(white_torch)
