简体   繁体   English

matplotlib 子图没有间距,限制图形大小和tight_layout()

[英]matplotlib subplots with no spacing, restricted figure size and tight_layout()

I am trying to make a plot which has two subplots which share the x axis and shall have no space in between them.我正在尝试制作一个 plot ,它有两个共享 x 轴的子图,它们之间不应有空格。 I followd the Create adjacent subplots example from the matplotlib gallery.我按照 matplotlib 库中的创建相邻子图示例进行操作。 However, my plot needs to have a fixed size and this makes everything complicated.然而,我的 plot 需要有一个固定的尺寸,这让一切变得复杂。 If I just follow the example and add a fixed size figure size, then the labels are cut off.如果我只是按照示例添加一个固定大小的图形大小,那么标签就会被切断。 If I include the labels by using tight_layout , then the plots are spaced.如果我使用tight_layout包含标签,那么这些图是间隔的。 How to fix this?如何解决这个问题? Also, the title should be closer to the legend.此外,标题应该更接近图例。 Any help is much appreciated!任何帮助深表感谢!

Example program, comment out tight_layout to see the difference.示例程序,注释掉tight_layout看看有什么不同。

import numpy as np                                                                                             
import matplotlib.pyplot as plt                                                                                
                                                                                                               
x_min = -2*np.pi                                                                                               
x_max = 2*np.pi                                                                                                
resolution = 101                                                                                               
x_vals = np.linspace(x_min, x_max, resolution)                                                                 
y_upper = np.cos(x_vals)                                                                                       
y_lower = -np.cos(x_vals)                                                                                      
data3 = np.sin(x_vals)                                                                                         
                                                                                                               
fig = plt.figure(figsize=(80/25.4, 80/25.4))  # figsize is needed for later usage of the plot                  
ax = fig.subplots(2, 1, sharex=True)                                                                           
fig.subplots_adjust(hspace=0)                                                                                  
                                                                                                               
ax[0].plot(x_vals, y_upper, label="data 1")                                                                    
ax[0].plot(x_vals, y_lower, label="data 2")                                                                    
                                                                                                               
ax[1].set_xlim([x_min,x_max])                                                                                  
ax[0].set_ylim([-1.6,1.6])                                                                                     
ax[1].set_ylim([-1.3,1.3])                                                                                     
                                                                                                               
ax[1].plot(x_vals, data3, ls='-', label="data 3", color='C2')                                                  
                                                                                                               
ax[1].set_xlabel("xaxis")                                                                                      
ax[0].set_ylabel("yaxis 1")                                                                                    
ax[1].set_ylabel("yaxis 2")                                                                                    
ax[0].legend(bbox_to_anchor=(0, 1.02, 1., 0.102), loc='lower left', ncol=2, mode="expand", borderaxespad=0)    
                                                                                                               
fig.suptitle("Title")                                                                                          
fig.tight_layout()  # comment this out to see the difference                                                   
# fig.savefig('figure.png')                                                                                    
plt.show()

You need to use a GridSpec instead of subplots_adjust() , that way tight_layout() will know that you want zero-space and it keep it that way.您需要使用GridSpec而不是 subplots_adjust subplots_adjust() ,这样, tight_layout()就会知道您想要零空间并保持这种状态。

In fact, you are already creating a GridSpec when you use fig.subplots() , so you just need to pass some extra parameter in gridspec_kw=事实上,当您使用fig.subplots()时,您已经创建了一个GridSpec ,因此您只需要在gridspec_kw=中传递一些额外的参数

x_min = -2*np.pi                                                                                               
x_max = 2*np.pi                                                                                                
resolution = 101                                                                                               
x_vals = np.linspace(x_min, x_max, resolution)                                                                 
y_upper = np.cos(x_vals)                                                                                       
y_lower = -np.cos(x_vals)                                                                                      
data3 = np.sin(x_vals)                                                                                         
                                                                                                               
fig = plt.figure(figsize=(80/25.4, 80/25.4))  # figsize is needed for later usage of the plot             
#
# This is the line that changes. Instruct the gridspec to have zero vertical pad
#     
ax = fig.subplots(2, 1, sharex=True, gridspec_kw=dict(hspace=0))                                                                           
                                                                                          
ax[0].plot(x_vals, y_upper, label="data 1")                                                                    
ax[0].plot(x_vals, y_lower, label="data 2")                                                                    
                                                                                                               
ax[1].set_xlim([x_min,x_max])                                                                                  
ax[0].set_ylim([-1.6,1.6])                                                                                     
ax[1].set_ylim([-1.3,1.3])                                                                                     
                                                                                                               
ax[1].plot(x_vals, data3, ls='-', label="data 3", color='C2')                                                  
                                                                                                               
ax[1].set_xlabel("xaxis")                                                                                      
ax[0].set_ylabel("yaxis 1")                                                                                    
ax[1].set_ylabel("yaxis 2")                                                                                    
ax[0].legend(bbox_to_anchor=(0, 1.02, 1., 0.102), loc='lower left', ncol=2, mode="expand", borderaxespad=0)    
                                                                                                               
fig.suptitle("Title")                                                                                          
fig.tight_layout()  # Now tight_layout does not add padding between axes
# fig.savefig('figure.png')                                                                                    
plt.show()

在此处输入图像描述

It can be frustrating to get precise results with subplots - using gridspec ( https://matplotlib.org/3.3.3/tutorials/intermediate/gridspec.html ) will give your greater precision.使用子图获得精确结果可能会令人沮丧 - 使用 gridspec ( https://matplotlib.org/3.3.3/tutorials/intermediate/gridspec.html ) 将提供更高的精度。

However, given where you are, I think you can get what you want with this:但是,考虑到您所在的位置,我认为您可以通过以下方式获得所需的东西:

import matplotlib.pyplot as plt                                                                                
                                                                                                               
x_min = -2*np.pi                                                                                               
x_max = 2*np.pi                                                                                                
resolution = 101                                                                                               
x_vals = np.linspace(x_min, x_max, resolution)                                                                 
y_upper = np.cos(x_vals)                                                                                       
y_lower = -np.cos(x_vals)                                                                                      
data3 = np.sin(x_vals)                                                                                         
                                                                                                               
fig = plt.figure(figsize=(80/25.4, 80/25.4))  # figsize is needed for later usage of the plot                  
ax = fig.subplots(3, 1, sharex=True)                                                                           
fig.subplots_adjust(hspace=0)                                                                                  
ax[0].text(0,0.5,"Title", ha='center')
ax[0].axis("off")
ax[1].plot(x_vals, y_upper, label="data 1")                                                                    
ax[1].plot(x_vals, y_lower, label="data 2")                                                                    
                                                                                                               
ax[2].set_xlim([x_min,x_max])                                                                                  
ax[1].set_ylim([-1.6,1.6])                                                                                     
ax[2].set_ylim([-1.3,1.3])                                                                                     
                                                                                                               
ax[2].plot(x_vals, data3, ls='-', label="data 3", color='C2')                                                  
                                                                                                               
ax[2].set_xlabel("xaxis")                                                                                      
ax[1].set_ylabel("yaxis 1")                                                                                    
ax[2].set_ylabel("yaxis 2")                                                                                    
ax[1].legend(bbox_to_anchor=(0, 1.02, 1., 0.102), loc='lower left', ncol=2, mode="expand", borderaxespad=0)    

#fig.tight_layout()  # comment this out to see the difference                                                   
# fig.savefig('figure.png')                                                                                    
plt.show()

在此处输入图像描述

Of course, gridspec is the correct approach, and if you are in early phases of the script writing, you should adapt this .当然, gridspec是正确的方法,如果你处于脚本编写的早期阶段,你应该适应这个. However, if you want an easy fix, you could also move fig.subplots_adjust() :但是,如果你想要一个简单的修复,你也可以移动fig.subplots_adjust()

#...
fig.suptitle("Title")                                                                                          
fig.tight_layout()    
fig.subplots_adjust(hspace=0)                                            
# fig.savefig('figure.png')                                                                                    
plt.show()

Saved image:保存的图像:
在此处输入图像描述

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM