how to draw multiple seaborn `distplot` in a single window?

I am trying to draw multiple seaborn distplots in a single window. I know how to generate a density plot for a single list of data, as shown in my code below (the make_density function). However, I am not sure how to draw several distplots in one window. Suppose my list stat_list contains 6 lists as its elements, and I want to draw one distplot for each of these 6 lists. How can I draw the 6 distplots in the same window, with 3 plots per row (so that the output has 2 rows of 3 plots)?

Thank you,


import matplotlib.pyplot as plt
import seaborn as sns

# function to plot the histogram for a single list.
def make_density(stat_list, color, x_label, y_label):
    
    # Plot formatting
    plt.xlabel(x_label)
    plt.ylabel(y_label)

    # Draw the histogram and fit a density plot.
    sns.distplot(stat_list, hist = True, kde = True,
                 kde_kws = {'linewidth': 2}, color=color)
    
    # get the y-coordinates of the points of the density curve.
    dens_list = sns.distplot(stat_list, hist = False, kde = False,
             kde_kws = {'linewidth': 2}, color = color).get_lines()[0].get_data()[1].tolist()
        
    # find the maximum y-coordinates of the density curve.            
    max_dens_index = dens_list.index(max(dens_list))
    
    # find the mode of the density plot.
    mode_x = sns.distplot(stat_list, hist = False, kde = False,
             kde_kws = {'linewidth': 2}, color = color).get_lines()[0].get_data()[0].tolist()[max_dens_index]
    
    # draw a vertical line at the mode of the histogram.
    plt.axvline(mode_x, color='blue', linestyle='dashed', linewidth=1.5)
    plt.text(mode_x * 1.05, 0.16, 'Mode: {:.4f}'.format(mode_x))

# `stat_list` is a list of 6 lists.
# I want to draw the histogram and density plot of each of these
# 6 lists in a single window, with 3 plots per row,
# so in my example there would be 2 rows of 3 columns of plots (2 x 3 = 6).
stat_list = [[0.3, 0.5, 0.7, 0.3, 0.5], [0.2, 0.1, 0.9, 0.7, 0.4], [0.9, 0.8, 0.7, 0.6, 0.5],
             [0.2, 0.6, 0.75, 0.87, 0.91], [0.2, 0.3, 0.8, 0.9, 0.3], [0.2, 0.3, 0.8, 0.87, 0.92]]

I would use seaborn's FacetGrid class for this.

Simple version:

import seaborn
seaborn.set(style='ticks', context='paper')

axgrid = (
    seaborn.load_dataset('titanic')
        .pipe(seaborn.FacetGrid, hue='deck', col='deck', col_wrap=3, sharey=False)
        .map(seaborn.distplot, 'fare')
)

Or with some modifications to your function:

from matplotlib import pyplot
import seaborn
seaborn.set(style='ticks', context='paper')


# function to plot the histogram for a single list.
def make_density(stat, color=None, x_label=None, y_label=None, ax=None, label=None):
   
    if not ax:
        ax = pyplot.gca()
    # Draw the histogram and fit a density plot.
    seaborn.distplot(stat, hist=True, kde=True,
                     kde_kws={'linewidth': 2}, color=color, ax=ax)

    # get the y-coordinates of the points of the density curve.
    dens_list = ax.get_lines()[0].get_data()[1]

    # find the maximum y-coordinates of the density curve.
    max_dens_index = dens_list.argmax()

    # find the mode of the density plot.
    mode_x = ax.get_lines()[0].get_data()[0][max_dens_index]

    # draw a vertical line at the mode of the histogram.
    ax.axvline(mode_x, color=color, linestyle='dashed', linewidth=1.5)
    ymax = ax.get_ylim()[1]
    ax.text(mode_x * 1.1, ymax * 0.16, 'Mode: {:.4f}'.format(mode_x))

    # Plot formatting
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)


axgrid = (
    seaborn.load_dataset('titanic')
        .pipe(seaborn.FacetGrid, hue='deck', col='deck', col_wrap=3, sharey=False)
        .map(make_density, 'fare')
)

You can create a grid of subplots with fig, axes = plt.subplots(...). Then pass each ax of the returned axes array as the ax= parameter of sns.distplot(). Note that you'll also need that same ax to set the labels: plt.xlabel() only changes the current subplot (right after plt.subplots(), that is the last one created).
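A small matplotlib-only sketch of the labelling pitfall: plt.xlabel() reaches just one subplot, while Axes methods address each subplot explicitly. The label strings are placeholders.

```python
import matplotlib.pyplot as plt

# A 2 x 3 grid; `axes` is a 2-D NumPy array of Axes objects.
fig, axes = plt.subplots(nrows=2, ncols=3)

# pyplot-style labelling acts only on the current Axes, which right after
# plt.subplots() is the last subplot created (bottom right).
plt.xlabel('via plt.xlabel')

# Axes methods target each subplot explicitly; these are the same `ax`
# objects you would pass to a seaborn function as ax=.
for i, ax in enumerate(axes.flat):
    ax.set_ylabel(f'subplot {i}')

print([ax.get_xlabel() for ax in axes.flat])  # ['', '', '', '', '', 'via plt.xlabel']
```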

Calling sns.distplot three times is not recommended: every call draws into the same ax, so anything it plots is added on top of what is already there. It is enough to call it once and then read the KDE curve back from ax.get_lines(). Also note that you can use NumPy functions such as argmax() to find the maximum efficiently, without converting to Python lists (which are quite slow when there is a lot of data).
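A matplotlib-only sketch of both points. Here ax.plot stands in for sns.distplot (an assumption made only to keep the example independent of the seaborn version), and the curve is a synthetic bump rather than a real KDE.

```python
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
x = np.linspace(0, 1, 101)
y = np.exp(-((x - 0.3) ** 2) / 0.02)  # synthetic bump peaking at x = 0.3

# Each plotting call on the same Axes adds a new line instead of replacing the old one.
for _ in range(3):
    ax.plot(x, y)
print(len(ax.get_lines()))  # 3

# Read the first curve back and locate its peak with NumPy, without .tolist().
xs, ys = ax.get_lines()[0].get_data()
mode_x = xs[np.argmax(ys)]
print(round(float(mode_x), 2))  # 0.3
```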

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# function to plot the histogram for a single list.
def make_density(stat, color, x_label, y_label, ax):
    # Draw the histogram and fit a density plot.
    sns.distplot(stat, hist=True, kde=True,
                 kde_kws={'linewidth': 2}, color=color, ax=ax)

    # get the y-coordinates of the points of the density curve.
    dens_list = ax.get_lines()[0].get_data()[1]

    # find the maximum y-coordinates of the density curve.
    max_dens_index = dens_list.argmax()

    # find the mode of the density plot.
    mode_x = ax.get_lines()[0].get_data()[0][max_dens_index]

    # draw a vertical line at the mode of the histogram.
    ax.axvline(mode_x, color='blue', linestyle='dashed', linewidth=1.5)
    ax.text(mode_x * 1.05, 0.16, 'Mode: {:.4f}'.format(mode_x))

    # Plot formatting
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

stat_list = [[0.3, 0.5, 0.7, 0.3, 0.5], [0.2, 0.1, 0.9, 0.7, 0.4], [0.9, 0.8, 0.7, 0.6, 0.5],
             [0.2, 0.6, 0.75, 0.87, 0.91], [0.2, 0.3, 0.8, 0.9, 0.3], [0.2, 0.3, 0.8, 0.87, 0.92]]
num_subplots = len(stat_list)
ncols = 3
nrows = (num_subplots + ncols - 1) // ncols
fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(ncols * 6, nrows * 5))
colors = plt.cm.tab10.colors
for ax, stat, color in zip(np.ravel(axes), stat_list, colors):
    make_density(stat, color, 'x_label', 'y_label', ax)
for ax in np.ravel(axes)[num_subplots:]:  # remove possible empty subplots at the end
    ax.remove()
plt.show()

[Image: resulting plot]
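The nrows expression and the ax.remove() loop in the code above only matter when the number of datasets does not fill the grid. This sketch uses a hypothetical count of 7 datasets with 3 columns and builds only the empty grid, to show what those two steps do.

```python
import math

import matplotlib.pyplot as plt
import numpy as np

num_subplots = 7  # hypothetical number of datasets, not a multiple of ncols
ncols = 3
# Integer ceiling division: (7 + 3 - 1) // 3 == 3, the same as math.ceil(7 / 3).
nrows = (num_subplots + ncols - 1) // ncols
assert nrows == math.ceil(num_subplots / ncols)

fig, axes = plt.subplots(nrows=nrows, ncols=ncols)
# 3 x 3 = 9 subplots were created, but only the first 7 would receive a dataset.
for ax in np.ravel(axes)[num_subplots:]:
    ax.remove()  # drop the 2 unused subplots so no empty frames are shown

print(nrows, len(fig.axes))  # 3 7
```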

PS: Instead of distplot, you could also use histplot (new in Seaborn 0.11, the release that deprecated distplot). It should give a nicer plot, especially when there are only a few data points and/or the data are discrete.

sns.histplot(stat, kde=True, line_kws={'linewidth': 2}, color=color, ax=ax)

[Image: histogram]
