[英]Creating an image grid of tensors out of a batch using matplotlib/pytorch
我正在尝试从一批张量中创建一个图像网格(例如 3 x 3),这些张量将在下一步通过数据加载器输入到 GAN 中。 使用下面的代码,我能够将张量转换为显示在正确位置的网格中的图像。 问题是,它们都显示在一个单独的网格中,如下所示:图 1图 2图 5 。 如何将它们放在一个网格中,然后只返回一个包含所有 9 个图像的图形? 也许我把它弄得太复杂了。 :D 最后,real_samples 中的张量必须被转换并放入网格中。
real_samples = next(iter(train_loader))
for i in range(9):
plt.figure(figsize=(9, 9))
plt.subplot(330 + i + 1)
plt.imshow(np.transpose(vutils.make_grid(real_samples[i].to(device)
[:40], padding=1, normalize=True).cpu(),(1,2,0)))
plt.show()
以下是如何使用 matplotlib XD 显示可变数量的精彩 CryptoPunk:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import matplotlib.image as mpimg
row_count = 3
col_count = 3
cryptopunks = [
mpimg.imread(f"cryptopunks/{i}.png") for i in range(row_count * col_count)
]
fig = plt.figure(figsize=(8., 8.))
grid = ImageGrid(fig, 111, nrows_ncols=(row_count, col_count), axes_pad=0.1)
for ax, im in zip(grid, cryptopunks):
ax.imshow(im)
plt.show()
请注意,代码允许您生成所需的所有图像,而不仅仅是 3 次 3。我有一个名为cryptopunks
的文件夹,其中包含许多名为#.png
的图像(例如1.png
、...、 34.png
, ...)。 只需更改row_count
和col_count
变量值。 例如,对于row_count=6
和col_count=8
你会得到:
如果您的图像文件没有上述命名模式(即,只是随机名称),只需将第一行替换为以下行:
import os
from pathlib import Path
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import matplotlib.image as mpimg
for root, _, filenames in os.walk("cryptopunks/"):
cryptopunks = [
mpimg.imread(Path(root, filename)) for filename in filenames
]
row_count = 3
col_count = 3
# Here same code as above.
(我已经从Kaggle下载了 CryptoPunks 数据集。)
Matplotlib 提供了一个名为 subplot 的函数,我想这就是您要寻找的!
plt.subplot(9,1) 是我猜的语法。
然后配置你的情节
对于 9 个图像网格,您可能需要三行和三列。 可能最简单的方法是这样的:
import matplotlib.pyplot as plt
fig, axes = plt.subplots(nrows=3, ncols=3)
axes = axes.flatten() # Flattens the array so you can access individual axes
for ax in axes:
# Do stuff with your individual axes here
plt.show() # This call is here just for example, prob. better call outside of the loop
它使用plt.tight_layout()
输出以下轴配置:
您可能还对matplotlib 镶嵌功能或gridspec one感兴趣。 希望这可以帮助。
编辑:这是一个解决方案,它用它的数字注释每个地块,这样你就可以看到去哪里了:
import matplotlib.pyplot as plt
fig, axes = plt.subplots(nrows=3, ncols=3)
axes = axes.flatten()
for idx, ax in enumerate(axes):
ax.annotate(f"{idx+1}", xy=(0.5, 0.5), xytext=(0.5, 0.5))
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.