簡體   English   中英

如何刪除 matplotlib.pyplot 中子圖之間的空格?

[英]How to remove the space between subplots in matplotlib.pyplot?

我正在做一個項目,我需要將 10 行 3 列的 plot 網格放在一起。 雖然我已經能夠制作情節並安排子情節,但我無法制作出沒有空白的漂亮 plot,例如下面來自gridspec 文檔的空白。 沒有空白的圖像 .

我嘗試了以下帖子,但仍然無法像示例圖像中那樣完全刪除空白區域。 有人可以給我一些指導嗎? 謝謝!

這是我的形象: 我的形象

下面是我的代碼。 完整的腳本在 GitHub 上 注意:images_2 和 images_fool 都是 numpy arrays 形狀為 (1032, 10) 的扁平圖像,而 delta 是形狀為 (28, 28) 的圖像數組。

def plot_im(array=None, ind=0):
    """A function to plot the image given a images matrix, type of the matrix: \
    either original or fool, and the order of images in the matrix"""
    img_reshaped = array[ind, :].reshape((28, 28))
    imgplot = plt.imshow(img_reshaped)

# Output as a grid of 10 rows and 3 cols with first column being original, second being
# delta and third column being adversaril
nrow = 10
ncol = 3
n = 0

from matplotlib import gridspec
fig = plt.figure(figsize=(30, 30)) 
gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1]) 

for row in range(nrow):
    for col in range(ncol):
        plt.subplot(gs[n])
        if col == 0:
            #plt.subplot(nrow, ncol, n)
            plot_im(array=images_2, ind=row)
        elif col == 1:
            #plt.subplot(nrow, ncol, n)
            plt.imshow(w_delta)
        else:
            #plt.subplot(nrow, ncol, n)
            plot_im(array=images_fool, ind=row)
        n += 1

plt.tight_layout()
#plt.show()
plt.savefig('grid_figure.pdf')

開頭的注意事項:如果您想完全控制間距,請避免使用plt.tight_layout()因為它會嘗試將圖中的圖plt.tight_layout()均勻且良好的分布。 這通常很好並產生令人愉快的結果,但可以隨意調整間距。

您從 Matplotlib 示例庫中引用的 GridSpec 示例運行良好的原因是因為子圖的方面沒有預定義。 也就是說,子圖將簡單地在網格上擴展,並使設置的間距(在這種情況下wspace=0.0, hspace=0.0 )與圖形大小無關。

與此相反,您使用imshow繪制圖像並且圖像的縱橫比默認設置為相等(相當於ax.set_aspect("equal") )。 也就是說,您當然可以將set_aspect("auto")放在每個繪圖中(並另外添加wspace=0.0, hspace=0.0作為 GridSpec 的參數,如畫廊示例中所示),這將生成一個沒有間距的繪圖。

然而,當使用圖像時,保持相等的縱橫比是很有意義的,這樣每個像素都和高一樣寬,方形陣列顯示為方形圖像。
然后您需要做的是調整圖像大小和圖形邊距以獲得預期結果。 figsize參數是以英寸為單位的圖形(寬度,高度),這里可以使用兩個數字的比率。 並且可以手動調整子圖參數wspace, hspace, top, bottom, left以獲得所需的結果。 下面是一個例子:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec

nrow = 10
ncol = 3

fig = plt.figure(figsize=(4, 10)) 

gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1],
         wspace=0.0, hspace=0.0, top=0.95, bottom=0.05, left=0.17, right=0.845) 

for i in range(10):
    for j in range(3):
        im = np.random.rand(28,28)
        ax= plt.subplot(gs[i,j])
        ax.imshow(im)
        ax.set_xticklabels([])
        ax.set_yticklabels([])

#plt.tight_layout() # do not use this!!
plt.show()

在此處輸入圖片說明

編輯:
不必手動調整參數當然是可取的。 所以可以根據行數和列數計算出一些最優值。

nrow = 7
ncol = 7

fig = plt.figure(figsize=(ncol+1, nrow+1)) 

gs = gridspec.GridSpec(nrow, ncol,
         wspace=0.0, hspace=0.0, 
         top=1.-0.5/(nrow+1), bottom=0.5/(nrow+1), 
         left=0.5/(ncol+1), right=1-0.5/(ncol+1)) 

for i in range(nrow):
    for j in range(ncol):
        im = np.random.rand(28,28)
        ax= plt.subplot(gs[i,j])
        ax.imshow(im)
        ax.set_xticklabels([])
        ax.set_yticklabels([])

plt.show()

嘗試將此行添加到您的代碼中:

fig.subplots_adjust(wspace=0, hspace=0)

對於每個軸對象集:

ax.set_xticklabels([])
ax.set_yticklabels([])

按照ImportanceOfBeingErnest的回答,但如果您想使用plt.subplots及其功能:

fig, axes = plt.subplots(
    nrow, ncol,
    gridspec_kw=dict(wspace=0.0, hspace=0.0,
                     top=1. - 0.5 / (nrow + 1), bottom=0.5 / (nrow + 1),
                     left=0.5 / (ncol + 1), right=1 - 0.5 / (ncol + 1)),
    figsize=(ncol + 1, nrow + 1),
    sharey='row', sharex='col', #  optionally
)

如果您正在使用 matplotlib.pyplot.subplots,您可以使用軸 arrays 顯示任意數量的圖像。您可以通過對 matplotlib.pyplot.subplots 配置進行一些調整來刪除圖像之間的空間。

import matplotlib.pyplot as plt

def show_dataset_overview(self, img_list):
"""show each image in img_list without space"""
    img_number = len(img_list)
    img_number_at_a_row = 3
    row_number = int(img_number /img_number_at_a_row) 
    fig_size = (15*(img_number_at_a_row/row_number), 15)
    _, axs = plt.subplots(row_number, 
                          img_number_at_a_row, 
                          figsize=fig_size , 
                          gridspec_kw=dict(
                                       top = 1, bottom = 0, right = 1, left = 0, 
                                       hspace = 0, wspace = 0
                                       )
                         )
    axs = axs.flatten()

    for i in range(img_number):
        axs[i].imshow(img_list[i])
        axs[i].set_xticks([])
        axs[i].set_yticks([])

由於我們首先在此處創建子圖,因此我們可以使用 gridspec_kw 參數 ( source ) 為 grid_spec 提供一些參數。 在這些參數中,“top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0”參數將防止圖像間距。 要查看其他參數,請訪問此處

在設置上面的 figure_size 時,我通常使用像 (30,15) 這樣的圖形大小。 我對此進行了概括並將其添加到代碼中。 如果您願意,可以在此處輸入手動尺寸。

這是使用ImageGrid class 的另一種簡單方法(改編自此答案)。

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

nrow = 5
ncol = 3
fig = plt.figure(figsize=(4, 10))
grid = ImageGrid(fig, 
                 111, # as in plt.subplot(111)
                 nrows_ncols=(nrow,ncol),
                 axes_pad=0,
                 share_all=True,)

for row in grid.axes_column:
    for ax in row:
        im = np.random.rand(28,28)
        ax.imshow(im)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

在此處輸入圖像描述

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM