
How to plot multiple seaborn.distplot in a single figure

I want to plot multiple seaborn distplots in the same window, where each plot has the same x and y grid. My attempt, shown below, does not work.

# function to plot the density curve of the 200 Median Stn. MC-losses
def make_density(stat_list,color, layer_num):

    num_subplots = len(stat_list)
    ncols = 3
    nrows = (num_subplots + ncols - 1) // ncols
    fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(ncols * 6, nrows * 5))
    
    for i in range(len(stat_list)):
        
        # Plot formatting
        plt.title('Layer ' + layer_num)
        plt.xlabel('Median Stn. MC-Loss')
        plt.ylabel('Density')
        plt.xlim(-0.2,0.05)
        plt.ylim(0, 85)
        min_ylim, max_ylim = plt.ylim()
    
        # Draw the density plot.
        sns.distplot(stat_list, hist = True, kde = True,
                 kde_kws = {'linewidth': 2}, color=color)

# `stat_list` is a list of 6 lists
# I want to draw histogram and density plot of 
# each of these 6 lists contained in `stat_list` in a single window,
# where each row containing the histograms and densities of the 3 plots
# so in my example, there would be 2 rows of 3 columns of plots (2 x 3 =6).
stat_list = [[0.3,0.5,0.7,0.3,0.5],[0.2,0.1,0.9,0.7,0.4],[0.9,0.8,0.7,0.6,0.5],
          [0.2,0.6,0.75,0.87,0.91],[0.2,0.3,0.8,0.9,0.3],[0.2,0.3,0.8,0.87,0.92]]

How can I modify my function to draw multiple distplots in the same window, where the x and y grid for each displayed plot is identical?

Thank you,

PS: As an aside, I want the 6 distplots to have the same color, preferably green for all of them.

  • The easiest method is to load the data into pandas and then use seaborn.displot.
  • .displot replaces .distplot in seaborn version 0.11.0.
    • Technically, what you would have wanted before is a FacetGrid mapped with distplot.
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# data
stat_list = [[0.3,0.5,0.7,0.3,0.5], [0.2,0.1,0.9,0.7,0.4], [0.9,0.8,0.7,0.6,0.5], [0.2,0.6,0.75,0.87,0.91], [0.2,0.3,0.8,0.9,0.3], [0.2,0.3,0.8,0.87,0.92]]

# load the data into pandas and then transpose it for the correct column data
df = pd.DataFrame(stat_list).T

# name the columns; specify a layer number
df.columns = ['A', 'B', 'C', 'D', 'E', 'F']

# now stack the data into a long (tidy) format
dfl = df.stack().reset_index(level=1).rename(columns={'level_1': 'Layer', 0: 'Median Stn. MC-Loss'})

# plot a displot
g = sns.displot(data=dfl, x='Median Stn. MC-Loss', col='Layer', col_wrap=3, kde=True, color='green')
g.set_axis_labels(y_var='Density')
g.set(xlim=(0, 1.0), ylim=(0, 3.0))


sns.FacetGrid and sns.distplot

  • .distplot is deprecated
p = sns.FacetGrid(data=dfl, col='Layer', col_wrap=3, height=5)
p.map(sns.distplot, 'Median Stn. MC-Loss', bins=5, kde=True, color='green')
p.set(xlim=(0, 1.0))

There is a general solution in a free library of seventeen matplotlib graphics utilities, with a user guide, here: https://www.mlbridgeresearch.com/products/free-article-2 . I got tired of interrupting my research to write utility software, so I have accumulated libraries that address common needs. The code is well documented and works well. The example calls histogram_grid() from the library, which returns the plot grid as a matplotlib Figure. Because histograms generally do not share the same ranges, the standard call does not produce exactly what you asked for, so the adjustments are made on the returned Figure.

import pandas as pd
import matplotlib.pyplot as plt

from statistics_utilities import histogram_grid


stat_list = [[0.3, 0.5, 0.7, 0.3, 0.5], [0.2, 0.1, 0.9, 0.7, 0.4], [0.9, 0.8, 0.7, 0.6, 0.5],
            [0.2, 0.6, 0.75, 0.87, 0.91], [0.2, 0.3, 0.8, 0.9, 0.3], [0.2, 0.3, 0.8, 0.87, 0.92]]

df = pd.DataFrame(stat_list).transpose()
# histogram_grid() accepts only a DataFrame and requires named columns.
df.columns = ['x1', 'x2', 'x3', 'x4', 'x5', 'x6']

# If kde is True, the plot is a density plot no matter how hist_type is set.
hist_type = 'density'
variable_names = df.columns
bins = 3
fig = histogram_grid(df, bins=bins, hist_type=hist_type, kde=True, legend=False,
                     title='test title', variable_names=variable_names,
                     n_gridcolumns=3, height=6, width=10)
fig.subplots_adjust(wspace=.2, left=0.035, right=.95, bottom=.13)

# Adjust the axes on the 2 x 3 grid plot:
# turn off x-axis labels/ticks in the top row and y-axis
# labels/ticks in every column except the first.
axes_list = fig.axes            # get a list of Axes in Figure
ax_index = 0
modify_xaxes_indexes = [0, 1, 2]
modify_yaxes_indexes = [1, 2, 4, 5]
for ax in axes_list:
    ax.set_xlabel(None)
    ax.set_ylabel(None)
    # normally, the xlim() would be calculated but I can see that
    # .1 <= x <= .92 and similarly the densities are 0 <= y <= 3.
    ax.set_xlim(.05, .95)
    ax.set_ylim(0, 3)
    if ax_index in modify_xaxes_indexes:
        ax.tick_params(
            axis='x',  # changes apply to the x-axis
            which='both',  # both major and minor ticks are affected
            bottom=False,  # ticks along the bottom edge are off
            top=False,  # ticks along the top edge are off
            labelbottom=False)  # labels along the bottom edge are off
    if ax_index in modify_yaxes_indexes:
        ax.tick_params(
            axis='y',  # changes apply to the y-axis
            which='both',  # both major and minor ticks are affected
            left=False,  # ticks along the left edge are off
            right=False,  # ticks along the right edge are off
            labelleft=False)  # labels along the left edge are off
    ax_index += 1

plt.show()
plt.close()
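
The index bookkeeping in the loop above can likely be avoided with matplotlib's own Axes.label_outer(), which keeps x labels and tick labels only on the bottom row and y labels and tick labels only on the first column. The sketch below is a plain-matplotlib illustration and does not use the histogram_grid() library. The bin count, limits and figure size are copied from the example above, and the axis label texts are assumptions. Depending on the matplotlib version, label_outer() may hide only the tick labels and leave the tick marks in place.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove this line to open a window
import matplotlib.pyplot as plt

stat_list = [[0.3, 0.5, 0.7, 0.3, 0.5], [0.2, 0.1, 0.9, 0.7, 0.4],
             [0.9, 0.8, 0.7, 0.6, 0.5], [0.2, 0.6, 0.75, 0.87, 0.91],
             [0.2, 0.3, 0.8, 0.9, 0.3], [0.2, 0.3, 0.8, 0.87, 0.92]]

fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(10, 6))
for ax, values in zip(axes.ravel(), stat_list):
    ax.hist(values, bins=3, density=True, color="green")
    ax.set_xlim(0.05, 0.95)  # same limits on every panel
    ax.set_ylim(0, 3)
    ax.set_xlabel("Median Stn. MC-Loss")
    ax.set_ylabel("Density")
    # Keep labels and tick labels only on the outer edge of the grid.
    ax.label_outer()
```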

Histogram grid plot
