简体   繁体   English

如何在多个子图中 plot

[英]How to plot in multiple subplots

I am a little confused about how this code works:我对这段代码的工作方式有点困惑:

fig, axes = plt.subplots(nrows=2, ncols=2)
plt.show()

How does the fig, axes work in this case?在这种情况下,无花果轴如何工作? What does it do?它有什么作用?

Also why wouldn't this work to do the same thing:还有为什么这不能做同样的事情:

fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)

There are several ways to do it.有几种方法可以做到这一点。 The subplots method creates the figure along with the subplots that are then stored in the ax array. subplots方法创建图形以及随后存储在ax数组中的子图。 For example:例如:

import matplotlib.pyplot as plt

x = range(10)
y = range(10)

fig, ax = plt.subplots(nrows=2, ncols=2)

for row in ax:
    for col in row:
        col.plot(x, y)

plt.show()

在此处输入图像描述

However, something like this will also work, it's not so "clean" though since you are creating a figure with subplots and then add on top of them:然而,这样的事情也可以工作,但它不是那么“干净”,因为你正在创建一个带有子图的图形,然后在它们之上添加:

fig = plt.figure()

plt.subplot(2, 2, 1)
plt.plot(x, y)

plt.subplot(2, 2, 2)
plt.plot(x, y)

plt.subplot(2, 2, 3)
plt.plot(x, y)

plt.subplot(2, 2, 4)
plt.plot(x, y)

plt.show()

在此处输入图像描述

import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 2)

ax[0, 0].plot(range(10), 'r') #row=0, col=0
ax[1, 0].plot(range(10), 'b') #row=1, col=0
ax[0, 1].plot(range(10), 'g') #row=0, col=1
ax[1, 1].plot(range(10), 'k') #row=1, col=1
plt.show()

在此处输入图像描述

  • You can also unpack the axes in the subplots call您还可以在 subplots 调用中解压缩轴

  • And set whether you want to share the x and y axes between the subplots并设置是否要在子图之间共享 x 和 y 轴

Like this:像这样:

import matplotlib.pyplot as plt
# fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
ax1, ax2, ax3, ax4 = axes.flatten()

ax1.plot(range(10), 'r')
ax2.plot(range(10), 'b')
ax3.plot(range(10), 'g')
ax4.plot(range(10), 'k')
plt.show()

2 x 2 绘图

You might be interested in the fact that as of matplotlib version 2.1 the second code from the question works fine as well.您可能对从 matplotlib 版本 2.1 开始该问题的第二个代码也可以正常工作这一事实感兴趣。

From the change log :更改日志

Figure class now has subplots method The Figure class now has a subplots() method which behaves the same as pyplot.subplots() but on an existing figure. Figure 类现在有 subplots 方法 Figure 类现在有一个 subplots() 方法,其行为与 pyplot.subplots() 相同,但在现有图形上。

Example:例子:

import matplotlib.pyplot as plt

fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)

plt.show()

Read the documentation: matplotlib.pyplot.subplots阅读文档: matplotlib.pyplot.subplots

pyplot.subplots() returns a tuple fig, ax which is unpacked in two variables using the notation pyplot.subplots()返回一个元组fig, ax使用符号解包在两个变量中

fig, axes = plt.subplots(nrows=2, ncols=2)

The code:编码:

fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)

does not work because subplots() is a function in pyplot not a member of the object Figure .不起作用,因为 subplots subplots()pyplot中的一个函数,而不是对象Figure的成员。

Iterating through all subplots sequentially:依次遍历所有子图:

fig, axes = plt.subplots(nrows, ncols)

for ax in axes.flatten():
    ax.plot(x,y)

Accessing a specific index:访问特定索引:

for row in range(nrows):
    for col in range(ncols):
        axes[row,col].plot(x[row], y[col])

Subplots with pandas带有熊猫的子图

  • This answer is for subplots with pandas , which uses matplotlib as the default plotting backend.这个答案适用于带有pandas的子图,它使用matplotlib作为默认的绘图后端。
  • Here are four options to create subplots starting with a pandas.DataFrame以下是创建以pandas.DataFrame开头的子图的四个选项
    • Implementation 1. and 2. are for the data in a wide format, creating subplots for each column.实现 1. 和 2. 用于宽格式的数据,为每列创建子图。
    • Implementation 3. and 4. are for data in a long format, creating subplots for each unique value in a column.实现 3. 和 4. 用于长格式数据,为列中的每个唯一值创建子图。
  • Tested in python 3.8.11 , pandas 1.3.2 , matplotlib 3.4.3 , seaborn 0.11.2python 3.8.11pandas 1.3.2matplotlib 3.4.3seaborn 0.11.2

Imports and Data导入和数据

import seaborn as sns  # data only
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# wide dataframe
df = sns.load_dataset('planets').iloc[:, 2:5]

   orbital_period   mass  distance
0         269.300   7.10     77.40
1         874.774   2.21     56.95
2         763.000   2.60     19.84
3         326.030  19.40    110.62
4         516.220  10.50    119.47

# long dataframe
dfm = sns.load_dataset('planets').iloc[:, 2:5].melt()

         variable    value
0  orbital_period  269.300
1  orbital_period  874.774
2  orbital_period  763.000
3  orbital_period  326.030
4  orbital_period  516.220

1. subplots=True and layout , for each column 1. subplots=Truelayout ,对于每一列

  • Use the parameters subplots=True and layout=(rows, cols) in pandas.DataFrame.plotpandas.DataFrame.plot中使用参数subplots=Truelayout=(rows, cols)
  • This example uses kind='density' , but there are different options for kind , and this applies to them all.此示例使用kind='density' ,但kind有不同的选项,这适用于所有选项。 Without specifying kind , a line plot is the default.如果不指定kind ,则默认使用折线图。
  • ax is array of AxesSubplot returned by pandas.DataFrame.plot axAxesSubplot返回的pandas.DataFrame.plot数组
  • See How to get a Figure object , if needed.如果需要,请参阅如何获取Figure对象
axes = df.plot(kind='density', subplots=True, layout=(2, 2), sharex=False, figsize=(10, 6))

# extract the figure object; only used for tight_layout in this example
fig = axes[0][0].get_figure() 

# set the individual titles
for ax, title in zip(axes.ravel(), df.columns):
    ax.set_title(title)
fig.tight_layout()
plt.show()

2. plt.subplots , for each column 2. plt.subplots ,对于每一列

  • Create an array of Axes with matplotlib.pyplot.subplots and then pass axes[i, j] or axes[n] to the ax parameter.使用matplotlib.pyplot.subplots创建一个Axes数组,然后将axes[i, j]axes[n]传递给ax参数。
    • This option uses pandas.DataFrame.plot , but can use other axes level plot calls as a substitute (eg sns.kdeplot , plt.plot , etc.)此选项使用pandas.DataFrame.plot ,但可以使用其他axes级别绘图调用作为替代(例如sns.kdeplotplt.plot等)
    • It's easiest to collapse the subplot array of Axes into one dimension with .ravel or .flatten .使用.ravel.flattenAxes的子图数组折叠成一维是最简单的。 See .ravel vs .flatten .请参阅.ravel.flatten
    • Any variables applying to each axes , that need to be iterate through, are combined with .zip (eg cols , axes , colors , palette , etc.).适用于每个axes的任何变量,需要迭代,都与.zip组合(例如colsaxescolorspalette等)。 Each object must be the same length.每个对象的长度必须相同。
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6))  # define the figure and subplots
axes = axes.ravel()  # array to 1D
cols = df.columns  # create a list of dataframe columns to use
colors = ['tab:blue', 'tab:orange', 'tab:green']  # list of colors for each subplot, otherwise all subplots will be one color

for col, color, ax in zip(cols, colors, axes):
    df[col].plot(kind='density', ax=ax, color=color, label=col, title=col)
    ax.legend()
    
fig.delaxes(axes[3])  # delete the empty subplot
fig.tight_layout()
plt.show()

Result for 1. and 2. 1. 和 2 的结果。

在此处输入图像描述

3. plt.subplots , for each group in .groupby 3. plt.subplots ,对于.groupby中的每个组

  • This is similar to 2., except it zips color and axes to a .groupby object.这与 2. 类似,不同之处在于它将coloraxes压缩到.groupby对象。
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6))  # define the figure and subplots
axes = axes.ravel()  # array to 1D
dfg = dfm.groupby('variable')  # get data for each unique value in the first column
colors = ['tab:blue', 'tab:orange', 'tab:green']  # list of colors for each subplot, otherwise all subplots will be one color

for (group, data), color, ax in zip(dfg, colors, axes):
    data.plot(kind='density', ax=ax, color=color, title=group, legend=False)

fig.delaxes(axes[3])  # delete the empty subplot
fig.tight_layout()
plt.show()

在此处输入图像描述

4. seaborn figure-level plot 4. seaborn人物级剧情

  • Use a seaborn figure-level plot, and use the col or row parameter.使用seaborn图形级别图,并使用colrow参数。 seaborn is a high-level API for matplotlib . seabornmatplotlib的高级 API。 See seaborn: API reference请参阅seaborn:API 参考
p = sns.displot(data=dfm, kind='kde', col='variable', col_wrap=2, x='value', hue='variable',
                facet_kws={'sharey': False, 'sharex': False}, height=3.5, aspect=1.75)
sns.move_legend(p, "upper left", bbox_to_anchor=(.55, .45))

The other answers are great, this answer is a combination which might be useful.其他答案很好,这个答案是一个可能有用的组合。

import numpy as np
import matplotlib.pyplot as plt

# Optional: define x for all the sub-plots
x = np.linspace(0,2*np.pi,100)

# (1) Prepare the figure infrastructure 
fig, ax_array = plt.subplots(nrows=2, ncols=2)

# flatten the array of axes, which makes them easier to iterate through and assign
ax_array = ax_array.flatten()

# (2) Plot loop
for i, ax in enumerate(ax_array):
  ax.plot(x , np.sin(x + np.pi/2*i))
  #ax.set_title(f'plot {i}')

# Optional: main title
plt.suptitle('Plots')

代码结果:绘图

Summary概括

  1. Prepare the figure infrastructure准备图基础设施
    • Get ax_array, an array of the subplots获取 ax_array,子图的数组
    • Flatten the array in order to use it in one 'for loop'展平数组以便在一个“for循环”中使用它
  2. Plot loop绘图循环
    • Loop over the flattened ax_array to update the subplots循环平展的 ax_array 以更新子图
    • optional: use enumeration to track subplot number可选:使用枚举来跟踪子图号
  3. Once flattened, each ax_array can be individually indexed from 0 through nrows x ncols -1 (eg ax_array[0] , ax_array[1] , ax_array[2] , ax_array[3] ).一旦展平,每个ax_array可以从0nrows x ncols -1单独索引(例如ax_array[0]ax_array[1]ax_array[2]ax_array[3] )。

Convert the axes array to 1Daxes数组转换为 1D

  • Generating subplots with plt.subplots(nrows, ncols) , where both nrows and ncols is greater than 1, returns a nested array of <AxesSubplot:> objects.使用plt.subplots(nrows, ncols)生成子图,其中 nrows 和 ncols大于 1,返回<AxesSubplot:>对象的嵌套数组。
    • It's not necessary to flatten axes in cases where either nrows=1 or ncols=1 , because axes will already be 1 dimensional, which is a result of the default parameter squeeze=Truenrows=1ncols=1的情况下,没有必要展平axes ,因为axes已经是一维的,这是默认参数squeeze=True的结果
  • The easiest way to access the objects, is to convert the array to 1 dimension with .ravel() , .flatten() , or .flat .访问对象的最简单方法是使用.ravel().flatten().flat将数组转换为一维。
    • .ravel vs. .flatten .ravel.flatten
      • flatten always returns a copy. flatten总是返回一个副本。
      • ravel returns a view of the original array whenever possible. ravel尽可能返回原始数组的视图。
  • Once the array of axes is converted to 1-d, there are a number of ways to plot.axes数组转换为 1-d 后,有多种绘图方法。
  • This answer is relevant to seaborn axes-level plots, which have the ax= parameter (eg sns.barplot(..., ax=ax[0]) .该答案与具有ax=参数(例如sns.barplot(..., ax=ax[0])的 seaborn 轴级图相关。
import matplotlib.pyplot as plt
import numpy as np  # sample data only

# example of data
rads = np.arange(0, 2*np.pi, 0.01)
y_data = np.array([np.sin(t*rads) for t in range(1, 5)])
x_data = [rads, rads, rads, rads]

# Generate figure and its subplots
fig, axes = plt.subplots(nrows=2, ncols=2)

# axes before
array([[<AxesSubplot:>, <AxesSubplot:>],
       [<AxesSubplot:>, <AxesSubplot:>]], dtype=object)

# convert the array to 1 dimension
axes = axes.ravel()

# axes after
array([<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
      dtype=object)
  1. Iterate through the flattened array遍历展平的数组
    • If there are more subplots than data, this will result in IndexError: list index out of range如果子图比数据多,这将导致IndexError: list index out of range
      • Try option 3. instead, or select a subset of the axes (eg axes[:-2] )尝试选项 3. 代替,或选择轴的子集(例如axes[:-2]
for i, ax in enumerate(axes):
    ax.plot(x_data[i], y_data[i])
  1. Access each axes by index按索引访问每个轴
axes[0].plot(x_data[0], y_data[0])
axes[1].plot(x_data[1], y_data[1])
axes[2].plot(x_data[2], y_data[2])
axes[3].plot(x_data[3], y_data[3])
  1. Index the data and axes索引数据和轴
for i in range(len(x_data)):
    axes[i].plot(x_data[i], y_data[i])
  1. zip the axes and data together and then iterate through the list of tuples将轴和数据zip在一起,然后遍历元组列表
for ax, x, y in zip(axes, x_data, y_data):
    ax.plot(x, y)

Ouput输出

在此处输入图像描述


  • An option is to assign each axes to a variable, fig, (ax1, ax2, ax3) = plt.subplots(1, 3) .一个选项是将每个轴分配给一个变量fig, (ax1, ax2, ax3) = plt.subplots(1, 3) However, as written, this only works in cases with either nrows=1 or ncols=1 .但是,正如所写,这只适用于nrows=1ncols=1的情况。 This is based on the shape of the array returned by plt.subplots , and quickly becomes cumbersome.这是基于plt.subplots返回的数组的形状,很快就会变得很麻烦。
    • fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2) for a 2 x 2 array. fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)对于 2 x 2 数组。
    • This option is most useful for two subplots (eg: fig, (ax1, ax2) = plt.subplots(1, 2) or fig, (ax1, ax2) = plt.subplots(2, 1) ).此选项对两个子图最有用(例如: fig, (ax1, ax2) = plt.subplots(1, 2)fig, (ax1, ax2) = plt.subplots(2, 1) )。 For more subplots, it's more efficient to flatten and iterate through the array of axes.对于更多子图,展平和迭代轴数组会更有效。

here is a simple solution这是一个简单的解决方案

fig, ax = plt.subplots(nrows=2, ncols=3, sharex=True, sharey=False)
for sp in fig.axes:
    sp.plot(range(10))

输出

Go with the following if you really want to use a loop:如果您真的想使用循环,请执行以下操作:

def plot(data):
    fig = plt.figure(figsize=(100, 100))
    for idx, k in enumerate(data.keys(), 1):
        x, y = data[k].keys(), data[k].values
        plt.subplot(63, 10, idx)
        plt.bar(x, y)  
    plt.show()

Another concise solution is:另一个简洁的解决方案是:

// set up structure of plots
f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20,10))

// for plot 1
ax1.set_title('Title A')
ax1.plot(x, y)

// for plot 2
ax2.set_title('Title B')
ax2.plot(x, y)

// for plot 3
ax3.set_title('Title C')
ax3.plot(x,y)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM