[英]Convert pytorch tensor to numpy, and reshape
我有一个 pytorch 张量[100, 1, 32, 32]
对应于 100 个图像、1 个通道、高度 32 和宽度 32 的批量大小。我想重塑这个张量以具有尺寸 [32*10, 32*10],这样图像表示为 10x10 网格,前 10 个图像位于第 1 行,依此类推。 如何做到这一点?
我不完全理解你的问题,但试图解决一些问题。
你有一个形状为
[100, 1, 32, 32]
的张量,它代表 100 个形状为[1, 32, 32]
的图像,其中num_channels = 1
,width = 32
,height = 32
。
首先,由于图像只有一个通道,我们可以压缩通道维度。
# image_tensor is of shape [100, 1, 32, 32]
image_tensor = image_tensor.squeeze(1) # [100, 32, 32]
如您所述,我们可以将生成的张量组织成 10 行,每行 10 张图像。
image_tensor = image_tensor.reshape(10, 10, 32, 32)
现在,将生成的张量转换为形状为[32*10, 32*10]
的张量听起来有些不对劲。 但是,让我们做错事,看看我们最终会得到什么。
image_tensor = image_tensor.permute(2, 0, 3, 1) # [32, 10, 32, 10]
排列后,我们得到一个形状为[width, num_rows, height, num_img_in_a_row]
的张量。 然后最后我们可以重塑以获得所需的张量。
image_tensor = image_tensor.reshape(32*10, 32*10)
因此,最终张量的形状为[width * num_rows, height * num_img_in_a_row]
。 你真的想要这个吗? 我不知道如何解释产生的张量!
更新
更高效、更短的版本。 为了避免使用 for 循环,我们可以先置换a
。
import torch
a = torch.arange(9*2*2).view(9,1,2,2)
b = a.permute([0,1,3,2])
torch.cat(torch.split(b, 3),-1).view(6,6).t()
# tensor([[ 0, 1, 4, 5, 8, 9],
# [ 2, 3, 6, 7, 10, 11],
# [12, 13, 16, 17, 20, 21],
# [14, 15, 18, 19, 22, 23],
# [24, 25, 28, 29, 32, 33],
# [26, 27, 30, 31, 34, 35]])
原始答案
您可以使用torch.split
和torch.cat
来实现它。
import torch
a = torch.arange(9*2*2).view(9,1,2,2)
假设我们有a
张量,它是原始张量的迷你版。 看起来,
tensor([[[[ 0, 1],
[ 2, 3]]],
[[[ 4, 5],
[ 6, 7]]],
[[[ 8, 9],
[10, 11]]],
[[[12, 13],
[14, 15]]],
[[[16, 17],
[18, 19]]],
[[[20, 21],
[22, 23]]],
[[[24, 25],
[26, 27]]],
[[[28, 29],
[30, 31]]],
[[[32, 33],
[34, 35]]]])
每个 2x2 子矩阵可以看作一个图像。 您要做的是将前三个图像堆叠到一行,接下来的三个图像到第二行,最后三个图像到第三行。 由于 2x2 子矩阵,“行”实际上有两个暗淡。
three_parts = torch.split(a,3)
torch.cat(torch.split(three_parts[0],1), dim=-1)
#tensor([[[[ 0, 1, 4, 5, 8, 9],
# [ 2, 3, 6, 7, 10, 11]]]])
这里我们只取第一部分。
torch.cat([torch.cat(torch.split(three_parts[i],1),-1) for i in range(3)],0).view(6,6)
# tensor([[ 0, 1, 4, 5, 8, 9],
# [ 2, 3, 6, 7, 10, 11],
# [12, 13, 16, 17, 20, 21],
# [14, 15, 18, 19, 22, 23],
# [24, 25, 28, 29, 32, 33],
# [26, 27, 30, 31, 34, 35]])
您可以使用make_grid()
:
x = torchvision.utils.make_grid(x, nrow=10, padding=0)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.