
How to draw multiple seaborn `distplot` in a single window?

I am trying to draw multiple seaborn distplots in a single window. I know how to generate a density plot for a single list of data, as shown in my code below (the make_density function). However, I am not sure how to draw several distplots in one window. Suppose my list stat_list contains 6 lists as its elements, and I want to draw one distplot from each of these 6 lists. How can I draw the 6 distplots in the same window, with 3 plots displayed in each row (so that my output has 2 rows of 3 plots)?

Thank you,


import matplotlib.pyplot as plt
import seaborn as sns

# function to plot the histogram for a single list.
def make_density(stat_list, color, x_label, y_label):
    
    # Plot formatting
    plt.xlabel(x_label)
    plt.ylabel(y_label)

    # Draw the histogram and fit a density plot.
    sns.distplot(stat_list, hist = True, kde = True,
                 kde_kws = {'linewidth': 2}, color=color)
    
    # get the y-coordinates of the points of the density curve.
    dens_list = sns.distplot(stat_list, hist = False, kde = False,
             kde_kws = {'linewidth': 2}, color = color).get_lines()[0].get_data()[1].tolist()
        
    # find the maximum y-coordinates of the density curve.            
    max_dens_index = dens_list.index(max(dens_list))
    
    # find the mode of the density plot.
    mode_x = sns.distplot(stat_list, hist = False, kde = False,
             kde_kws = {'linewidth': 2}, color = color).get_lines()[0].get_data()[0].tolist()[max_dens_index]
    
    # draw a vertical line at the mode of the histogram.
    plt.axvline(mode_x, color='blue', linestyle='dashed', linewidth=1.5)
    plt.text(mode_x * 1.05, 0.16, 'Mode: {:.4f}'.format(mode_x))

# `stat_list` is a list of 6 lists
# I want to draw histogram and density plot of 
# each of these 6 lists contained in `stat_list` in a single window,
# where each row containing the histograms and densities of the 3 plots
# so in my example, there would be 2 rows of 3 columns of plots (2 x 3 =6).
stat_list = [[0.3,0.5,0.7,0.3,0.5],[0.2,0.1,0.9,0.7,0.4],[0.9,0.8,0.7,0.6,0.5],
          [0.2,0.6,0.75,0.87,0.91],[0.2,0.3,0.8,0.9,0.3],[0.2,0.3,0.8,0.87,0.92]]

I would use seaborn's FacetGrid class for this.

Simple version:

import seaborn
seaborn.set(style='ticks', context='paper')

axgrid = (
    seaborn.load_dataset('titanic')
        .pipe(seaborn.FacetGrid, hue='deck', col='deck', col_wrap=3, sharey=False)
        .map(seaborn.distplot, 'fare')
)

Or with some modifications to your function:

from matplotlib import pyplot
import seaborn
seaborn.set(style='ticks', context='paper')


# function to plot the histogram for a single list.
def make_density(stat, color=None, x_label=None, y_label=None, ax=None, label=None):
   
    if ax is None:
        ax = pyplot.gca()
    # Draw the histogram and fit a density plot.
    seaborn.distplot(stat, hist=True, kde=True,
                     kde_kws={'linewidth': 2}, color=color, ax=ax)

    # get the y-coordinates of the points of the density curve.
    dens_list = ax.get_lines()[0].get_data()[1]

    # find the maximum y-coordinates of the density curve.
    max_dens_index = dens_list.argmax()

    # find the mode of the density plot.
    mode_x = ax.get_lines()[0].get_data()[0][max_dens_index]

    # draw a vertical line at the mode of the histogram.
    ax.axvline(mode_x, color=color, linestyle='dashed', linewidth=1.5)
    ymax = ax.get_ylim()[1]
    ax.text(mode_x * 1.1, ymax * 0.16, 'Mode: {:.4f}'.format(mode_x))

    # Plot formatting
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)


axgrid = (
    seaborn.load_dataset('titanic')
        .pipe(seaborn.FacetGrid, hue='deck', col='deck', col_wrap=3, sharey=False)
        .map(make_density, 'fare')
)
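
The examples above use the titanic dataset. To use FacetGrid with a list of lists such as stat_list from the question, the data first has to be in long form (one row per observation). A minimal sketch, assuming pandas is available; the helper name to_long_form and the column names 'group' and 'value' are made up for illustration:

```python
import pandas as pd


def to_long_form(stat_list):
    """Turn a list of lists into a long-form DataFrame (one row per value).

    The column names 'group' and 'value' are arbitrary choices for this sketch.
    """
    rows = [
        {"group": "list {}".format(i + 1), "value": value}
        for i, stat in enumerate(stat_list)
        for value in stat
    ]
    return pd.DataFrame(rows)


stat_list = [[0.3, 0.5, 0.7, 0.3, 0.5], [0.2, 0.1, 0.9, 0.7, 0.4], [0.9, 0.8, 0.7, 0.6, 0.5],
             [0.2, 0.6, 0.75, 0.87, 0.91], [0.2, 0.3, 0.8, 0.9, 0.3], [0.2, 0.3, 0.8, 0.87, 0.92]]
df = to_long_form(stat_list)
print(df.head())
```

The resulting frame should then work in place of the titanic data, for example df.pipe(seaborn.FacetGrid, col='group', col_wrap=3, sharey=False).map(make_density, 'value').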


You can create a grid of subplots with fig, axes = plt.subplots(...). Then you can pass each ax of the returned axes as the ax= parameter of sns.distplot(). Note that you'll need the same ax to set the labels; plt.xlabel() will only change one of the subplots.

Calling sns.distplot three times is not recommended: each call adds more and more information to the same ax. Also note that you can use numpy functions such as argmax() to efficiently find the maximum without converting to Python lists (which are quite slow when there is a lot of data).
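
To illustrate the argmax() point in isolation, here is a small sketch with a made-up curve (the arrays x and y are hypothetical stand-ins for the data returned by get_data(), not values from the question). It compares the list-based lookup with the NumPy one:

```python
import numpy as np

# Hypothetical density curve: x positions and the curve height at each one.
x = np.linspace(0.0, 1.0, 11)
y = np.exp(-((x - 0.7) ** 2) / 0.02)

# List-based approach from the question: convert, then scan the list twice.
y_list = y.tolist()
index_from_list = y_list.index(max(y_list))

# NumPy approach: a single vectorized call, no conversion needed.
index_from_numpy = int(np.argmax(y))

# The x value at the peak is the mode of the curve.
mode_x = x[index_from_numpy]
print(index_from_list, index_from_numpy, mode_x)
```

Both lookups return the same index; the NumPy version simply avoids the conversion and the extra pass over the data.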

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# function to plot the histogram for a single list.
def make_density(stat, color, x_label, y_label, ax):
    # Draw the histogram and fit a density plot.
    sns.distplot(stat, hist=True, kde=True,
                 kde_kws={'linewidth': 2}, color=color, ax=ax)

    # get the y-coordinates of the points of the density curve.
    dens_list = ax.get_lines()[0].get_data()[1]

    # find the maximum y-coordinates of the density curve.
    max_dens_index = dens_list.argmax()

    # find the mode of the density plot.
    mode_x = ax.get_lines()[0].get_data()[0][max_dens_index]

    # draw a vertical line at the mode of the histogram.
    ax.axvline(mode_x, color='blue', linestyle='dashed', linewidth=1.5)
    ax.text(mode_x * 1.05, 0.16, 'Mode: {:.4f}'.format(mode_x))

    # Plot formatting
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

stat_list = [[0.3, 0.5, 0.7, 0.3, 0.5], [0.2, 0.1, 0.9, 0.7, 0.4], [0.9, 0.8, 0.7, 0.6, 0.5],
             [0.2, 0.6, 0.75, 0.87, 0.91], [0.2, 0.3, 0.8, 0.9, 0.3], [0.2, 0.3, 0.8, 0.87, 0.92]]
num_subplots = len(stat_list)
ncols = 3
nrows = (num_subplots + ncols - 1) // ncols
fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(ncols * 6, nrows * 5))
colors = plt.cm.tab10.colors
for ax, stat, color in zip(np.ravel(axes), stat_list, colors):
    make_density(stat, color, 'x_label', 'y_label', ax)
for ax in np.ravel(axes)[num_subplots:]:  # remove possible empty subplots at the end
    ax.remove()
plt.show()
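
The nrows line above is an integer ceiling division, which is why the last loop may have leftover axes to remove. A minimal sketch of just that arithmetic (grid_shape is a hypothetical helper name, not part of matplotlib or seaborn):

```python
def grid_shape(num_plots, ncols=3):
    """Return (nrows, ncols) large enough to hold num_plots subplots."""
    # Integer ceiling division: adding ncols - 1 rounds any remainder up.
    nrows = (num_plots + ncols - 1) // ncols
    return nrows, ncols


for n in (5, 6, 7):
    nrows, ncols = grid_shape(n)
    unused = nrows * ncols - n  # trailing axes that would be removed
    print(n, "plots ->", nrows, "x", ncols, "grid,", unused, "unused")
```

With 6 lists and 3 columns the grid is exactly full, so the cleanup loop does nothing; it only matters when the number of lists is not a multiple of ncols.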

[Image: resulting plot]

PS: Instead of distplot, histplot (new in Seaborn 0.11) could also be used. This should give a nicer plot, especially when the data are few and/or discrete.

sns.histplot(stat, kde=True, line_kws={'linewidth': 2}, color=color, ax=ax)

[Image: histogram plot]
