[英]How to plot in multiple subplots
I am a little confused about how this code works:我对这段代码的工作方式有点困惑:
fig, axes = plt.subplots(nrows=2, ncols=2)
plt.show()
How does the fig, axes work in this case?在这种情况下,无花果轴如何工作? What does it do?它有什么作用?
Also why wouldn't this work to do the same thing:还有为什么这不能做同样的事情:
fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)
There are several ways to do it.有几种方法可以做到这一点。 The subplots
method creates the figure along with the subplots that are then stored in the ax
array. subplots
方法创建图形以及随后存储在ax
数组中的子图。 For example:例如:
import matplotlib.pyplot as plt
x = range(10)
y = range(10)
fig, ax = plt.subplots(nrows=2, ncols=2)
for row in ax:
for col in row:
col.plot(x, y)
plt.show()
However, something like this will also work, it's not so "clean" though since you are creating a figure with subplots and then add on top of them:然而,这样的事情也可以工作,但它不是那么“干净”,因为你正在创建一个带有子图的图形,然后在它们之上添加:
fig = plt.figure()
plt.subplot(2, 2, 1)
plt.plot(x, y)
plt.subplot(2, 2, 2)
plt.plot(x, y)
plt.subplot(2, 2, 3)
plt.plot(x, y)
plt.subplot(2, 2, 4)
plt.plot(x, y)
plt.show()
You can also unpack the axes in the subplots call您还可以在 subplots 调用中解压缩轴
And set whether you want to share the x and y axes between the subplots并设置是否要在子图之间共享 x 和 y 轴
Like this:像这样:
import matplotlib.pyplot as plt
# fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
ax1, ax2, ax3, ax4 = axes.flatten()
ax1.plot(range(10), 'r')
ax2.plot(range(10), 'b')
ax3.plot(range(10), 'g')
ax4.plot(range(10), 'k')
plt.show()
You might be interested in the fact that as of matplotlib version 2.1 the second code from the question works fine as well.您可能对从 matplotlib 版本 2.1 开始该问题的第二个代码也可以正常工作这一事实感兴趣。
From the change log :从更改日志:
Figure class now has subplots method The Figure class now has a subplots() method which behaves the same as pyplot.subplots() but on an existing figure. Figure 类现在有 subplots 方法 Figure 类现在有一个 subplots() 方法,其行为与 pyplot.subplots() 相同,但在现有图形上。
Example:例子:
import matplotlib.pyplot as plt
fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)
plt.show()
Read the documentation: matplotlib.pyplot.subplots阅读文档: matplotlib.pyplot.subplots
pyplot.subplots()
returns a tuple fig, ax
which is unpacked in two variables using the notation pyplot.subplots()
返回一个元组fig, ax
使用符号解包在两个变量中
fig, axes = plt.subplots(nrows=2, ncols=2)
The code:编码:
fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)
does not work because subplots()
is a function in pyplot
not a member of the object Figure
.不起作用,因为 subplots subplots()
是pyplot
中的一个函数,而不是对象Figure
的成员。
Iterating through all subplots sequentially:依次遍历所有子图:
fig, axes = plt.subplots(nrows, ncols)
for ax in axes.flatten():
ax.plot(x,y)
Accessing a specific index:访问特定索引:
for row in range(nrows):
for col in range(ncols):
axes[row,col].plot(x[row], y[col])
pandas
, which uses matplotlib
as the default plotting backend.这个答案适用于带有pandas
的子图,它使用matplotlib
作为默认的绘图后端。pandas.DataFrame
以下是创建以pandas.DataFrame
开头的子图的四个选项
python 3.8.11
, pandas 1.3.2
, matplotlib 3.4.3
, seaborn 0.11.2
在python 3.8.11
、 pandas 1.3.2
、 matplotlib 3.4.3
、 seaborn 0.11.2
import seaborn as sns # data only
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# wide dataframe
df = sns.load_dataset('planets').iloc[:, 2:5]
orbital_period mass distance
0 269.300 7.10 77.40
1 874.774 2.21 56.95
2 763.000 2.60 19.84
3 326.030 19.40 110.62
4 516.220 10.50 119.47
# long dataframe
dfm = sns.load_dataset('planets').iloc[:, 2:5].melt()
variable value
0 orbital_period 269.300
1 orbital_period 874.774
2 orbital_period 763.000
3 orbital_period 326.030
4 orbital_period 516.220
subplots=True
and layout
, for each column 1. subplots=True
和layout
,对于每一列subplots=True
and layout=(rows, cols)
in pandas.DataFrame.plot
在pandas.DataFrame.plot
中使用参数subplots=True
和layout=(rows, cols)
kind='density'
, but there are different options for kind
, and this applies to them all.此示例使用kind='density'
,但kind
有不同的选项,这适用于所有选项。 Without specifying kind
, a line plot is the default.如果不指定kind
,则默认使用折线图。ax
is array of AxesSubplot
returned by pandas.DataFrame.plot
ax
是AxesSubplot
返回的pandas.DataFrame.plot
数组Figure
object , if needed.如果需要,请参阅如何获取Figure
对象。
axes = df.plot(kind='density', subplots=True, layout=(2, 2), sharex=False, figsize=(10, 6))
# extract the figure object; only used for tight_layout in this example
fig = axes[0][0].get_figure()
# set the individual titles
for ax, title in zip(axes.ravel(), df.columns):
ax.set_title(title)
fig.tight_layout()
plt.show()
plt.subplots
, for each column 2. plt.subplots
,对于每一列Axes
with matplotlib.pyplot.subplots
and then pass axes[i, j]
or axes[n]
to the ax
parameter.使用matplotlib.pyplot.subplots
创建一个Axes
数组,然后将axes[i, j]
或axes[n]
传递给ax
参数。
pandas.DataFrame.plot
, but can use other axes
level plot calls as a substitute (eg sns.kdeplot
, plt.plot
, etc.)此选项使用pandas.DataFrame.plot
,但可以使用其他axes
级别绘图调用作为替代(例如sns.kdeplot
、 plt.plot
等)Axes
into one dimension with .ravel
or .flatten
.使用.ravel
或.flatten
将Axes
的子图数组折叠成一维是最简单的。 See .ravel
vs .flatten
.请参阅.ravel
与.flatten
。axes
, that need to be iterate through, are combined with .zip
(eg cols
, axes
, colors
, palette
, etc.).适用于每个axes
的任何变量,需要迭代,都与.zip
组合(例如cols
、 axes
、 colors
、 palette
等)。 Each object must be the same length.每个对象的长度必须相同。fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6)) # define the figure and subplots
axes = axes.ravel() # array to 1D
cols = df.columns # create a list of dataframe columns to use
colors = ['tab:blue', 'tab:orange', 'tab:green'] # list of colors for each subplot, otherwise all subplots will be one color
for col, color, ax in zip(cols, colors, axes):
df[col].plot(kind='density', ax=ax, color=color, label=col, title=col)
ax.legend()
fig.delaxes(axes[3]) # delete the empty subplot
fig.tight_layout()
plt.show()
plt.subplots
, for each group in .groupby
3. plt.subplots
,对于.groupby
中的每个组color
and axes
to a .groupby
object.这与 2. 类似,不同之处在于它将color
和axes
压缩到.groupby
对象。fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6)) # define the figure and subplots
axes = axes.ravel() # array to 1D
dfg = dfm.groupby('variable') # get data for each unique value in the first column
colors = ['tab:blue', 'tab:orange', 'tab:green'] # list of colors for each subplot, otherwise all subplots will be one color
for (group, data), color, ax in zip(dfg, colors, axes):
data.plot(kind='density', ax=ax, color=color, title=group, legend=False)
fig.delaxes(axes[3]) # delete the empty subplot
fig.tight_layout()
plt.show()
seaborn
figure-level plot 4. seaborn
人物级剧情seaborn
figure-level plot, and use the col
or row
parameter.使用seaborn
图形级别图,并使用col
或row
参数。 seaborn
is a high-level API for matplotlib
. seaborn
是matplotlib
的高级 API。 See seaborn: API reference请参阅seaborn:API 参考p = sns.displot(data=dfm, kind='kde', col='variable', col_wrap=2, x='value', hue='variable',
facet_kws={'sharey': False, 'sharex': False}, height=3.5, aspect=1.75)
sns.move_legend(p, "upper left", bbox_to_anchor=(.55, .45))
The other answers are great, this answer is a combination which might be useful.其他答案很好,这个答案是一个可能有用的组合。
import numpy as np
import matplotlib.pyplot as plt
# Optional: define x for all the sub-plots
x = np.linspace(0,2*np.pi,100)
# (1) Prepare the figure infrastructure
fig, ax_array = plt.subplots(nrows=2, ncols=2)
# flatten the array of axes, which makes them easier to iterate through and assign
ax_array = ax_array.flatten()
# (2) Plot loop
for i, ax in enumerate(ax_array):
ax.plot(x , np.sin(x + np.pi/2*i))
#ax.set_title(f'plot {i}')
# Optional: main title
plt.suptitle('Plots')
ax_array
can be individually indexed from 0
through nrows x ncols -1
(eg ax_array[0]
, ax_array[1]
, ax_array[2]
, ax_array[3]
).一旦展平,每个ax_array
可以从0
到nrows x ncols -1
单独索引(例如ax_array[0]
、 ax_array[1]
、 ax_array[2]
、 ax_array[3]
)。axes
array to 1D将axes
数组转换为 1Dplt.subplots(nrows, ncols)
, where both nrows and ncols is greater than 1, returns a nested array of <AxesSubplot:>
objects.使用plt.subplots(nrows, ncols)
生成子图,其中 nrows 和 ncols均大于 1,返回<AxesSubplot:>
对象的嵌套数组。
axes
in cases where either nrows=1
or ncols=1
, because axes
will already be 1 dimensional, which is a result of the default parameter squeeze=True
在nrows=1
或ncols=1
的情况下,没有必要展平axes
,因为axes
已经是一维的,这是默认参数squeeze=True
的结果.ravel()
, .flatten()
, or .flat
.访问对象的最简单方法是使用.ravel()
、 .flatten()
或.flat
将数组转换为一维。
.ravel
vs. .flatten
.ravel
与.flatten
flatten
always returns a copy. flatten
总是返回一个副本。ravel
returns a view of the original array whenever possible. ravel
尽可能返回原始数组的视图。axes
is converted to 1-d, there are a number of ways to plot.将axes
数组转换为 1-d 后,有多种绘图方法。ax=
parameter (eg sns.barplot(..., ax=ax[0])
.该答案与具有ax=
参数(例如sns.barplot(..., ax=ax[0])
的 seaborn 轴级图相关。
seaborn
is a high-level API for matplotlib
. seaborn
是matplotlib
的高级 API。 See Figure-level vs. axes-level functions and seaborn is not plotting within defined subplots请参阅图形级与轴级函数,并且seaborn 未在定义的子图中绘制import matplotlib.pyplot as plt
import numpy as np # sample data only
# example of data
rads = np.arange(0, 2*np.pi, 0.01)
y_data = np.array([np.sin(t*rads) for t in range(1, 5)])
x_data = [rads, rads, rads, rads]
# Generate figure and its subplots
fig, axes = plt.subplots(nrows=2, ncols=2)
# axes before
array([[<AxesSubplot:>, <AxesSubplot:>],
[<AxesSubplot:>, <AxesSubplot:>]], dtype=object)
# convert the array to 1 dimension
axes = axes.ravel()
# axes after
array([<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
dtype=object)
IndexError: list index out of range
如果子图比数据多,这将导致IndexError: list index out of range
axes[:-2]
)尝试选项 3. 代替,或选择轴的子集(例如axes[:-2]
)for i, ax in enumerate(axes):
ax.plot(x_data[i], y_data[i])
axes[0].plot(x_data[0], y_data[0])
axes[1].plot(x_data[1], y_data[1])
axes[2].plot(x_data[2], y_data[2])
axes[3].plot(x_data[3], y_data[3])
for i in range(len(x_data)):
axes[i].plot(x_data[i], y_data[i])
zip
the axes and data together and then iterate through the list of tuples将轴和数据zip
在一起,然后遍历元组列表for ax, x, y in zip(axes, x_data, y_data):
ax.plot(x, y)
fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
.一个选项是将每个轴分配给一个变量fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
。 However, as written, this only works in cases with either nrows=1
or ncols=1
.但是,正如所写,这只适用于nrows=1
或ncols=1
的情况。 This is based on the shape of the array returned by plt.subplots
, and quickly becomes cumbersome.这是基于plt.subplots
返回的数组的形状,很快就会变得很麻烦。
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
for a 2 x 2 array. fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
对于 2 x 2 数组。fig, (ax1, ax2) = plt.subplots(1, 2)
or fig, (ax1, ax2) = plt.subplots(2, 1)
).此选项对两个子图最有用(例如: fig, (ax1, ax2) = plt.subplots(1, 2)
或fig, (ax1, ax2) = plt.subplots(2, 1)
)。 For more subplots, it's more efficient to flatten and iterate through the array of axes.对于更多子图,展平和迭代轴数组会更有效。Go with the following if you really want to use a loop:如果您真的想使用循环,请执行以下操作:
def plot(data):
fig = plt.figure(figsize=(100, 100))
for idx, k in enumerate(data.keys(), 1):
x, y = data[k].keys(), data[k].values
plt.subplot(63, 10, idx)
plt.bar(x, y)
plt.show()
Another concise solution is:另一个简洁的解决方案是:
// set up structure of plots
f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20,10))
// for plot 1
ax1.set_title('Title A')
ax1.plot(x, y)
// for plot 2
ax2.set_title('Title B')
ax2.plot(x, y)
// for plot 3
ax3.set_title('Title C')
ax3.plot(x,y)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.