简体   繁体   English

如何删除 matplotlib.pyplot 中子图之间的空格?

[英]How to remove the space between subplots in matplotlib.pyplot?

I am working on a project in which I need to put together a plot grid of 10 rows and 3 columns.我正在做一个项目,我需要将 10 行 3 列的 plot 网格放在一起。 Although I have been able to make the plots and arrange the subplots, I was not able to produce a nice plot without white space such as this one below from gridspec documentatation .虽然我已经能够制作情节并安排子情节,但我无法制作出没有空白的漂亮 plot,例如下面来自gridspec 文档的空白。 没有空白的图像 . .

I tried the following posts, but still not able to completely remove the white space as in the example image.我尝试了以下帖子,但仍然无法像示例图像中那样完全删除空白区域。 Can someone please give me some guidance?有人可以给我一些指导吗? Thanks!谢谢!

Here's my image:这是我的形象: 我的形象

Below is my code.下面是我的代码。 The full script is here on GitHub . 完整的脚本在 GitHub 上 Note: images_2 and images_fool are both numpy arrays of flattened images with shape (1032, 10), while delta is an image array of shape (28, 28).注意:images_2 和 images_fool 都是 numpy arrays 形状为 (1032, 10) 的扁平图像,而 delta 是形状为 (28, 28) 的图像数组。

def plot_im(array=None, ind=0):
    """A function to plot the image given a images matrix, type of the matrix: \
    either original or fool, and the order of images in the matrix"""
    img_reshaped = array[ind, :].reshape((28, 28))
    imgplot = plt.imshow(img_reshaped)

# Output as a grid of 10 rows and 3 cols with first column being original, second being
# delta and third column being adversaril
nrow = 10
ncol = 3
n = 0

from matplotlib import gridspec
fig = plt.figure(figsize=(30, 30)) 
gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1]) 

for row in range(nrow):
    for col in range(ncol):
        plt.subplot(gs[n])
        if col == 0:
            #plt.subplot(nrow, ncol, n)
            plot_im(array=images_2, ind=row)
        elif col == 1:
            #plt.subplot(nrow, ncol, n)
            plt.imshow(w_delta)
        else:
            #plt.subplot(nrow, ncol, n)
            plot_im(array=images_fool, ind=row)
        n += 1

plt.tight_layout()
#plt.show()
plt.savefig('grid_figure.pdf')

A note at the beginning: If you want to have full control over spacing, avoid using plt.tight_layout() as it will try to arange the plots in your figure to be equally and nicely distributed.开头的注意事项:如果您想完全控制间距,请避免使用plt.tight_layout()因为它会尝试将图中的图plt.tight_layout()均匀且良好的分布。 This is mostly fine and produces pleasant results, but adjusts the spacing at its will.这通常很好并产生令人愉快的结果,但可以随意调整间距。

The reason the GridSpec example you're quoting from the Matplotlib example gallery works so well is because the subplots' aspect is not predefined.您从 Matplotlib 示例库中引用的 GridSpec 示例运行良好的原因是因为子图的方面没有预定义。 That is, the subplots will simply expand on the grid and leave the set spacing (in this case wspace=0.0, hspace=0.0 ) independent of the figure size.也就是说,子图将简单地在网格上扩展,并使设置的间距(在这种情况下wspace=0.0, hspace=0.0 )与图形大小无关。

In contrast to that you are plotting images with imshow and the image's aspect is set equal by default (equivalent to ax.set_aspect("equal") ).与此相反,您使用imshow绘制图像并且图像的纵横比默认设置为相等(相当于ax.set_aspect("equal") )。 That said, you could of course put set_aspect("auto") to every plot (and additionally add wspace=0.0, hspace=0.0 as arguments to GridSpec as in the gallery example), which would produce a plot without spacings.也就是说,您当然可以将set_aspect("auto")放在每个绘图中(并另外添加wspace=0.0, hspace=0.0作为 GridSpec 的参数,如画廊示例中所示),这将生成一个没有间距的绘图。

However when using images it makes a lot of sense to keep an equal aspect ratio such that every pixel is as wide as high and a square array is shown as a square image.然而,当使用图像时,保持相等的纵横比是很有意义的,这样每个像素都和高一样宽,方形阵列显示为方形图像。
What you will need to do then is to play with the image size and the figure margins to obtain the expected result.然后您需要做的是调整图像大小和图形边距以获得预期结果。 The figsize argument to figure is the figure (width, height) in inch and here the ratio of the two numbers can be played with. figsize参数是以英寸为单位的图形(宽度,高度),这里可以使用两个数字的比率。 And the subplot parameters wspace, hspace, top, bottom, left can be manually adjusted to give the desired result.并且可以手动调整子图参数wspace, hspace, top, bottom, left以获得所需的结果。 Below is an example:下面是一个例子:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec

nrow = 10
ncol = 3

fig = plt.figure(figsize=(4, 10)) 

gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1],
         wspace=0.0, hspace=0.0, top=0.95, bottom=0.05, left=0.17, right=0.845) 

for i in range(10):
    for j in range(3):
        im = np.random.rand(28,28)
        ax= plt.subplot(gs[i,j])
        ax.imshow(im)
        ax.set_xticklabels([])
        ax.set_yticklabels([])

#plt.tight_layout() # do not use this!!
plt.show()

在此处输入图片说明

Edit:编辑:
It is of course desireable not having to tweak the parameters manually.不必手动调整参数当然是可取的。 So one could calculate some optimal ones according to the number of rows and columns.所以可以根据行数和列数计算出一些最优值。

nrow = 7
ncol = 7

fig = plt.figure(figsize=(ncol+1, nrow+1)) 

gs = gridspec.GridSpec(nrow, ncol,
         wspace=0.0, hspace=0.0, 
         top=1.-0.5/(nrow+1), bottom=0.5/(nrow+1), 
         left=0.5/(ncol+1), right=1-0.5/(ncol+1)) 

for i in range(nrow):
    for j in range(ncol):
        im = np.random.rand(28,28)
        ax= plt.subplot(gs[i,j])
        ax.imshow(im)
        ax.set_xticklabels([])
        ax.set_yticklabels([])

plt.show()

Try to add to your code this line:尝试将此行添加到您的代码中:

fig.subplots_adjust(wspace=0, hspace=0)

And for every an axis object set:对于每个轴对象集:

ax.set_xticklabels([])
ax.set_yticklabels([])

Following the answer by ImportanceOfBeingErnest, but if you want to use plt.subplots and its features:按照ImportanceOfBeingErnest的回答,但如果您想使用plt.subplots及其功能:

fig, axes = plt.subplots(
    nrow, ncol,
    gridspec_kw=dict(wspace=0.0, hspace=0.0,
                     top=1. - 0.5 / (nrow + 1), bottom=0.5 / (nrow + 1),
                     left=0.5 / (ncol + 1), right=1 - 0.5 / (ncol + 1)),
    figsize=(ncol + 1, nrow + 1),
    sharey='row', sharex='col', #  optionally
)

If you are using matplotlib.pyplot.subplots you can display as many images as you want using Axes arrays. You can remove the spaces between images by making some adjustments to the matplotlib.pyplot.subplots configuration.如果您正在使用 matplotlib.pyplot.subplots,您可以使用轴 arrays 显示任意数量的图像。您可以通过对 matplotlib.pyplot.subplots 配置进行一些调整来删除图像之间的空间。

import matplotlib.pyplot as plt

def show_dataset_overview(self, img_list):
"""show each image in img_list without space"""
    img_number = len(img_list)
    img_number_at_a_row = 3
    row_number = int(img_number /img_number_at_a_row) 
    fig_size = (15*(img_number_at_a_row/row_number), 15)
    _, axs = plt.subplots(row_number, 
                          img_number_at_a_row, 
                          figsize=fig_size , 
                          gridspec_kw=dict(
                                       top = 1, bottom = 0, right = 1, left = 0, 
                                       hspace = 0, wspace = 0
                                       )
                         )
    axs = axs.flatten()

    for i in range(img_number):
        axs[i].imshow(img_list[i])
        axs[i].set_xticks([])
        axs[i].set_yticks([])

Since we create subplots here first, we can give some parameters for grid_spec using the gridspec_kw parameter( source ).由于我们首先在此处创建子图,因此我们可以使用 gridspec_kw 参数 ( source ) 为 grid_spec 提供一些参数。 Among these parameters are the "top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0" parameters that will prevent inter-image spacing.在这些参数中,“top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0”参数将防止图像间距。 To see other parameters, please visit here .要查看其他参数,请访问此处

I usually use a figure size like (30,15) when setting the figure_size above.在设置上面的 figure_size 时,我通常使用像 (30,15) 这样的图形大小。 I generalized this a bit and added it to the code.我对此进行了概括并将其添加到代码中。 If you wish, you can enter a manual size here.如果您愿意,可以在此处输入手动尺寸。

Here's another simple approach using the ImageGrid class (adapted from this answer ).这是使用ImageGrid class 的另一种简单方法(改编自此答案)。

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

nrow = 5
ncol = 3
fig = plt.figure(figsize=(4, 10))
grid = ImageGrid(fig, 
                 111, # as in plt.subplot(111)
                 nrows_ncols=(nrow,ncol),
                 axes_pad=0,
                 share_all=True,)

for row in grid.axes_column:
    for ax in row:
        im = np.random.rand(28,28)
        ax.imshow(im)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

在此处输入图像描述

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM