
Convert pytorch tensor to numpy, and reshape

I have a pytorch tensor of shape [100, 1, 32, 32], corresponding to a batch of 100 images with 1 channel, height 32 and width 32. I want to reshape this tensor to have dimensions [32*10, 32*10], such that the images are arranged as a 10x10 grid, with the first 10 images in row 1, and so on. How can I achieve this?

I didn't completely understand your question, but I'll try to address some of your concerns.

You have a tensor of shape [100, 1, 32, 32] that represents 100 images of shape [1, 32, 32], where num_channels = 1, height = 32, and width = 32.

First, since the images have only one channel, we can squeeze the channel dimension.

# image_tensor is of shape [100, 1, 32, 32]
image_tensor = image_tensor.squeeze(1) # [100, 32, 32]

We can organize the resulting tensor into 10 rows of 10 images each, as you described.

image_tensor = image_tensor.reshape(10, 10, 32, 32)
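As a quick sanity check of this step (a minimal sketch on a random stand-in batch, not the asker's data; the names batch, images and grid4d are made up for illustration), the reshape only regroups the batch index: entry [r, c] of the result is image number r * 10 + c of the original batch.

```python
import torch

# Stand-in batch: 100 single-channel 32x32 images (random values, illustration only)
batch = torch.rand(100, 1, 32, 32)

images = batch.squeeze(1)                  # [100, 32, 32]
grid4d = images.reshape(10, 10, 32, 32)    # [grid_row, grid_col, height, width]

# The image at grid position (row 3, column 7) is image 3 * 10 + 7 = 37 of the batch
same = torch.equal(grid4d[3, 7], images[37])
print(tuple(grid4d.shape), same)
```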

Now, converting the resulting tensor into a tensor of shape [32*10, 32*10] sounds like something is wrong. But let's do that wrong thing and see what we end up with.

image_tensor = image_tensor.permute(2, 0, 3, 1) # [32, 10, 32, 10]

After the permutation, we get a tensor of shape [height, num_rows, width, num_img_in_a_row]. Finally, we can reshape it to get a tensor of the desired shape.

image_tensor = image_tensor.reshape(32*10, 32*10)

So the final tensor has shape [height * num_rows, width * num_img_in_a_row]. Do you really want this? I am not sure how to interpret the resulting tensor!
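For comparison, here is a hedged sketch of a different permutation that does seem to produce the 10x10 image grid the question describes: the height axis is moved next to the grid-row axis with permute(0, 2, 1, 3) before the final reshape. The random stand-in batch and the variable names are assumptions for illustration. The closing .numpy() call covers the conversion mentioned in the title.

```python
import torch

batch = torch.rand(100, 1, 32, 32)   # stand-in data, illustration only

grid = (
    batch.squeeze(1)                 # [100, 32, 32]
    .reshape(10, 10, 32, 32)         # [grid_row, grid_col, height, width]
    .permute(0, 2, 1, 3)             # [grid_row, height, grid_col, width]
    .reshape(10 * 32, 10 * 32)       # [320, 320]
)

# .numpy() works directly for CPU tensors that do not require grad
grid_np = grid.numpy()

# The tile at grid row 2, column 5 should be image 2 * 10 + 5 = 25
r, c = 2, 5
tile = grid[r * 32:(r + 1) * 32, c * 32:(c + 1) * 32]
tile_matches = torch.equal(tile, batch[r * 10 + c, 0])
print(grid_np.shape, tile_matches)
```

With this axis order, each block of 32 consecutive rows in the result holds one row of 10 images, so the array can be passed straight to an image viewer.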

Update

A more efficient and shorter version: to avoid the for loop, we can permute a first.

import torch
a = torch.arange(9*2*2).view(9,1,2,2)
b = a.permute([0,1,3,2])
torch.cat(torch.split(b, 3),-1).view(6,6).t()
# tensor([[ 0,  1,  4,  5,  8,  9],
#         [ 2,  3,  6,  7, 10, 11],
#         [12, 13, 16, 17, 20, 21],
#         [14, 15, 18, 19, 22, 23],
#         [24, 25, 28, 29, 32, 33],
#         [26, 27, 30, 31, 34, 35]])
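The same idea can be wrapped in a small helper for other sizes. This is only a sketch: the function name split_cat_grid and its ncols parameter are made up for illustration, and it assumes a single-channel batch whose size is a multiple of ncols.

```python
import torch

def split_cat_grid(a: torch.Tensor, ncols: int) -> torch.Tensor:
    """Tile a [N, 1, H, W] batch into a 2-D grid with ncols images per row."""
    n, ch, h, w = a.shape
    assert ch == 1 and n % ncols == 0
    nrows = n // ncols
    b = a.permute(0, 1, 3, 2)                    # transpose each image: [N, 1, W, H]
    wide = torch.cat(torch.split(b, ncols), -1)  # [ncols, 1, W, nrows * H]
    return wide.view(ncols * w, nrows * h).t()   # [nrows * H, ncols * W]

a = torch.arange(9 * 2 * 2).view(9, 1, 2, 2)
grid = split_cat_grid(a, ncols=3)
print(grid)
```

For the 9-image example it returns the same 6x6 tensor as the one-liner, and it also handles non-square images because the height and width are tracked separately.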

Original Answer

You can use torch.split and torch.cat to implement it.

import torch
a = torch.arange(9*2*2).view(9,1,2,2)

Assume we have a tensor a, which is a mini version of your original tensor. It looks like this:

tensor([[[[ 0,  1],
          [ 2,  3]]],
        [[[ 4,  5],
          [ 6,  7]]],
        [[[ 8,  9],
          [10, 11]]],
        [[[12, 13],
          [14, 15]]],
        [[[16, 17],
          [18, 19]]],
        [[[20, 21],
          [22, 23]]],
        [[[24, 25],
          [26, 27]]],
        [[[28, 29],
          [30, 31]]],
        [[[32, 33],
          [34, 35]]]])

Each 2x2 sub-matrix can be seen as one image. What you want to do is stack the first three images into the first row, the next three images into the second row, and the last three images into the third row. Each "row" actually has two dimensions because of the 2x2 sub-matrices.

three_parts = torch.split(a,3)

torch.cat(torch.split(three_parts[0],1), dim=-1)

#tensor([[[[ 0,  1,  4,  5,  8,  9],
#          [ 2,  3,  6,  7, 10, 11]]]])
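To make the intermediate shapes explicit, here is a small sketch of what the two torch.split calls return (the names part_shapes, singles, single_shapes and row are introduced only for illustration). An integer second argument is the chunk size along dim 0, which is the default dimension.

```python
import torch

a = torch.arange(9 * 2 * 2).view(9, 1, 2, 2)

# Chunk size 3 along dim 0: three chunks of three images each
three_parts = torch.split(a, 3)
part_shapes = [tuple(p.shape) for p in three_parts]

# Chunk size 1 on the first chunk: the three individual images
singles = torch.split(three_parts[0], 1)
single_shapes = [tuple(s.shape) for s in singles]

# Concatenating along the last dim places the images side by side
row = torch.cat(singles, dim=-1)
print(part_shapes, single_shapes, tuple(row.shape))
```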

Here we only take the first part. Repeating this for all three parts and concatenating the results along dimension 0 gives the full grid:

torch.cat([torch.cat(torch.split(three_parts[i],1),-1) for i in range(3)],0).view(6,6)
# tensor([[ 0,  1,  4,  5,  8,  9],
#         [ 2,  3,  6,  7, 10, 11],
#         [12, 13, 16, 17, 20, 21],
#         [14, 15, 18, 19, 22, 23],
#         [24, 25, 28, 29, 32, 33],
#         [26, 27, 30, 31, 34, 35]])

You can use make_grid() from torchvision:

import torchvision
x = torchvision.utils.make_grid(x, nrow=10, padding=0)  # [3, 320, 320]: the single channel is repeated 3 times
