繁体   English   中英

使用 matplotlib/pytorch 从批次中创建张量的图像网格

[英]Creating an image grid of tensors out of a batch using matplotlib/pytorch

我正在尝试从一批张量中创建一个图像网格(例如 3 x 3),这些张量将在下一步通过数据加载器输入到 GAN 中。 使用下面的代码,我能够将张量转换为显示在正确位置的网格中的图像。 问题是,它们都显示在一个单独的网格中,如下所示:图 1图 2图 5 如何将它们放在一个网格中,然后只返回一个包含所有 9 个图像的图形? 也许我把它弄得太复杂了。 :D 最后,real_samples 中的张量必须被转换并放入网格中。

real_samples = next(iter(train_loader))
for i in range(9):
    plt.figure(figsize=(9, 9))
    plt.subplot(330 + i + 1)
    plt.imshow(np.transpose(vutils.make_grid(real_samples[i].to(device)
                      [:40], padding=1, normalize=True).cpu(),(1,2,0)))
    plt.show()

以下是如何使用 matplotlib XD 显示可变数量的精彩 CryptoPunk:

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import matplotlib.image as mpimg

row_count = 3
col_count = 3

cryptopunks = [
    mpimg.imread(f"cryptopunks/{i}.png") for i in range(row_count * col_count)
]

fig = plt.figure(figsize=(8., 8.))
grid = ImageGrid(fig, 111, nrows_ncols=(row_count, col_count), axes_pad=0.1)

for ax, im in zip(grid, cryptopunks):
    ax.imshow(im)

plt.show()

在此处输入图像描述

请注意,代码允许您生成所需的所有图像,而不仅仅是 3 次 3。我有一个名为cryptopunks的文件夹,其中包含许多名为#.png的图像(例如1.png 、...、 34.png , ...)。 只需更改row_countcol_count变量值。 例如,对于row_count=6col_count=8你会得到: 在此处输入图像描述

如果您的图像文件没有上述命名模式(即,只是随机名称),只需将第一行替换为以下行:

import os
from pathlib import Path

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import matplotlib.image as mpimg

for root, _, filenames in os.walk("cryptopunks/"):
    cryptopunks = [
        mpimg.imread(Path(root, filename)) for filename in filenames
    ]

row_count = 3
col_count = 3

# Here same code as above.

(我已经从Kaggle下载了 CryptoPunks 数据集。)

Matplotlib 提供了一个名为 subplot 的函数,我想这就是您要寻找的!

plt.subplot(9,1) 是我猜的语法。

然后配置你的情节

对于 9 个图像网格,您可能需要三行和三列。 可能最简单的方法是这样的:

import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=3, ncols=3)
axes = axes.flatten() # Flattens the array so you can access individual axes

for ax in axes:
    # Do stuff with your individual axes here
    plt.show() # This call is here just for example, prob. better call outside of the loop

它使用plt.tight_layout()输出以下轴配置:

3x3 图像网格

您可能还对matplotlib 镶嵌功能或gridspec one感兴趣。 希望这可以帮助。

编辑:这是一个解决方案,它用它的数字注释每个地块,这样你就可以看到去哪里了:

import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=3, ncols=3)
axes = axes.flatten()

for idx, ax in enumerate(axes):
    ax.annotate(f"{idx+1}", xy=(0.5, 0.5), xytext=(0.5, 0.5))

带注释

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM