
Precise control over subplot locations in matplotlib

I am currently producing a figure for a paper, which looks like this:

[Figure: three image subplots (a), (b) and (c) with three colour bars]

The above is pretty close to how I want it to look, but I have a strong feeling that I'm not doing this the "right way", since it was really fiddly to produce, and my code is full of all sorts of magic numbers where I fine-tuned the positioning by hand. Thus my question is, what is the right way to produce a plot like this?

Here are the important features of this plot that made it hard to produce:

  • The aspect ratios of the three subplots are fixed by the data, but the images are not all at the same resolution.

  • I wanted all three plots to take up the full height of the figure.

  • I wanted (a) and (b) to be close together since they share their y axis, while (c) is further away.

  • Ideally, I would like the top of the top colour bar to exactly match the top of the three images, and similarly with the bottom of the lower colour bar. (In fact they aren't quite aligned, because I did this by guessing numbers and re-compiling the image.)

In producing this figure, I first tried using GridSpec, but I wasn't able to control the relative spacing between the three main subplots. I then tried ImageGrid, which is part of the AxesGrid toolkit, but the differing resolutions between the three images caused that to behave strangely. Delving deeper into AxesGrid, I was able to position the three main subplots using the append_axes function, but I still had to position the three colourbars by hand. (I created the colourbars manually.)

I'd rather not post my existing code, because it's a horrible collection of hacks and magic numbers. Rather, my question is: is there any way in Matplotlib to just specify the logical layout of the figure (i.e. the content of the bullet points above) and have the layout calculated for me automatically?

Here is a possible solution. Start from the figure width (which makes sense when preparing a paper) and work through the calculation using the aspect ratios of the images, some arbitrary spacings between the subplots, and the margins. The formulas are similar to the ones I used in this answer. The unequal aspects are taken care of by GridSpec's width_ratios argument. You end up with a figure height such that all subplots are equal in height.

So you cannot avoid typing in some numbers, but they are not "magic": all of them are related to accessible quantities such as a fraction of the figure size or a fraction of the mean subplot size. Since the system is closed, changing any number will simply produce a different figure height, but will not destroy the layout.
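For reference, the calculation can be written out as a short derivation. This is a sketch of my reading of the formulas in the code, not an official statement of how Matplotlib works. The symbols are introduced here for illustration: $W$ is the figure width (figw), $\ell, \rho, b, t$ are the left, right, bottom and top margins as figure fractions, $w$ is wspace, $n$ is the number of GridSpec columns, $r_i$ are the width ratios, and $a_i$ is the width/height ratio of image $i$.

```latex
\begin{align}
  s   &= \frac{W(\rho - \ell)}{n + (n-1)\,w}
      && \text{mean column width in inches} \\
  w_i &= s\,n\,\frac{r_i}{\sum_j r_j}
      && \text{width of column } i \\
  h   &= \frac{w_i}{a_i} = \frac{s\,n}{\sum_j r_j}
      && \text{image height, using } r_i = a_i \\
  H   &= \frac{h}{t - b}
      && \text{figure height in inches}
\end{align}
```

Because $r_i = a_i$ for every image column, $h$ does not depend on $i$, which is why all three images come out at the same height. In the code, $n / \sum_j r_j$ is the quantity called masp.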

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

np.random.seed(42)  # reproducible dummy images

imgs = []
shapes = [(550,200), ( 550,205), (1100,274) ]
for shape in shapes:
    imgs.append(np.random.random(shape))

# calculate inverse aspect(width/height) for all images
inva = np.array([ img.shape[1]/float(img.shape[0]) for img in imgs])
# set width of empty column used to stretch layout
emptycol = 0.02
r = np.array([inva[0],inva[1], emptycol, inva[2], 3*emptycol, emptycol])
# set a figure width in inch
figw = 8
# border, can be set independently of all other quantities
left = 0.1; right=1-left
bottom=0.1; top=1-bottom
# wspace (=average relative space between subplots)
wspace = 0.1
#calculate scale
s = figw*(right-left)/(len(r)+(len(r)-1)*wspace) 
# mean aspect
masp = len(r)/np.sum(r)
#calculate figheight
figh = s*masp/float(top-bottom)


gs = gridspec.GridSpec(3,len(r), width_ratios=r)

fig = plt.figure(figsize=(figw,figh))
# pass the margins by keyword so they cannot be mixed up
plt.subplots_adjust(left=left, bottom=bottom, right=right, top=top, wspace=wspace)

ax1 = plt.subplot(gs[:,0])
ax2 = plt.subplot(gs[:,1])
ax2.set_yticks([])

ax3 = plt.subplot(gs[:,3])
ax3.yaxis.tick_right()
ax3.yaxis.set_label_position("right")

cax1 = plt.subplot(gs[0,5])
cax2 = plt.subplot(gs[1,5])
cax3 = plt.subplot(gs[2,5])


im1 = ax1.imshow(imgs[0], cmap="viridis")
im2 = ax2.imshow(imgs[1], cmap="plasma")
im3 = ax3.imshow(imgs[2], cmap="RdBu")

fig.colorbar(im1, ax=ax1, cax=cax1)
fig.colorbar(im2, ax=ax2, cax=cax2)
fig.colorbar(im3, ax=ax3, cax=cax3)

ax1.set_title("image title")
ax1.set_xlabel("xlabel")
ax1.set_ylabel("ylabel")

plt.show()

[Figure: resulting plot]
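To see that the numbers form a closed system, the same arithmetic can be checked without drawing anything. The following is a minimal sketch in plain Python. The helper name `layout_inches` is hypothetical (it is not part of Matplotlib), and it assumes that GridSpec distributes column widths in proportion to width_ratios, with wspace measured as a fraction of the mean column width.

```python
def layout_inches(figw, ratios, left=0.1, right=0.9,
                  bottom=0.1, top=0.9, wspace=0.1):
    """Return (figh, col_widths, axes_h), all in inches.

    Hypothetical helper that mirrors the arithmetic of the answer;
    it assumes column widths are proportional to `ratios`.
    """
    n = len(ratios)
    # mean column width: available width shared by n columns and n-1 gaps
    mean_w = figw * (right - left) / (n + (n - 1) * wspace)
    # individual column widths, proportional to the ratios
    col_widths = [mean_w * n * r / sum(ratios) for r in ratios]
    # common height of the image axes (identical for every image column)
    axes_h = mean_w * n / sum(ratios)
    # figure height that leaves exactly axes_h between bottom and top
    figh = axes_h / (top - bottom)
    return figh, col_widths, axes_h


# same image shapes (rows, columns) and empty-column width as in the answer
shapes = [(550, 200), (550, 205), (1100, 274)]
inva = [w / h for h, w in shapes]          # width/height of each image
emptycol = 0.02
ratios = [inva[0], inva[1], emptycol, inva[2], 3 * emptycol, emptycol]

figh, col_widths, axes_h = layout_inches(8, ratios)

# the image axes sit in columns 0, 1 and 3 of the grid
for col, (h, w) in zip((0, 1, 3), shapes):
    box_aspect = axes_h / col_widths[col]  # height/width of the axes box
    print(col, round(box_aspect, 6), round(h / w, 6))
```

For each image column the two printed values should agree, meaning the axes box has the same height/width ratio as the image it holds. Changing figw, wspace or the margins only changes figh, not that agreement.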
