
Legends disappear when {"hist": False} in seaborn distplot

I have the following function:

Say hue="animal" has three categories (dog, bird, horse), and we have two DataFrames, df_m and df_f, containing data for male animals and female animals only, respectively.

The function draws three distplots of y (e.g. y="weight"), one for each hue level (dog, bird, horse). In each subplot it plots df_m[y] and df_f[y], so I can compare the weights of male vs. female dogs, male vs. female birds, and male vs. female horses.

If I call the function with distkwargs={"hist": False}, the legend entries ["F", "M"] disappear for some reason. With distkwargs={"hist": True}, the legend is shown.

    def plot_multi_kde_cat(self,dfs,y,hue,subkwargs={},distkwargs={},legends=[]):
        """
        Create a subplot multi_kde with categories in the same plot

        dfs: List
         - DataFrames for each category, e.g. one for males and one for females
        hue: string
         - column whose categories are each plotted in their own subplot
        """

        hues = dfs[0][hue].cat.categories
        if len(hues)==2: #Only two categories
            fig,axes = plt.subplots(1,2,**subkwargs) #Get axes and flatten them
            axes=axes.flatten()
            for ax,hu in zip(axes,hues):
                
                for df in dfs:
                    sns.distplot(df.loc[df[hue]==hu,y],ax=ax,**distkwargs)
                
                ax.set_title(f"Segment: {hu}")
                ax.legend(legends)
            
            

        else: #More than two categories: create a square grid and remove unused axes
            n_rows = int(np.ceil(np.sqrt(len(hues)))) #number of rows
            fig,axes = plt.subplots(n_rows,n_rows,**subkwargs) 

            axes = axes.flatten()

            for ax,hu in zip(axes,hues):
                for df in dfs:
                    
                    sns.distplot(df.loc[df[hue]==hu,y],ax=ax,**distkwargs)
                    
                ax.set_title(f"Segment: {hu}")
                ax.legend(legends)
            
            
            n_remove = len(axes)-len(hues) #number of axes to remove
            
            if n_remove>0:
                for ax in axes[-n_remove:]:
                    ax.set_visible(False)
        
        fig.tight_layout()
        
        return fig,axes
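Both branches above build a grid of axes and then run the same loop over it, so the grid sizing could be factored out. The following is a minimal sketch using only matplotlib; the helper name `make_grid` is hypothetical, and it uses ceil(sqrt(n)) columns with only as many rows as needed. For two or three categories this matches the original layouts (1 × 2 and 2 × 2), but for five categories it gives 2 × 3 instead of the 3 × 3 grid of the `else` branch.

```python
import math

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def make_grid(n_panels, **subkwargs):
    """Hypothetical helper: near-square grid with exactly n_panels visible axes."""
    n_cols = math.ceil(math.sqrt(n_panels))
    n_rows = math.ceil(n_panels / n_cols)
    # squeeze=False always returns a 2-D array, even for a 1x1 grid
    fig, axes = plt.subplots(n_rows, n_cols, squeeze=False, **subkwargs)
    axes = axes.flatten()
    # Hide the trailing axes that have no category to show
    for ax in axes[n_panels:]:
        ax.set_visible(False)
    return fig, axes[:n_panels]


fig, axes = make_grid(3)
```

Inside the question's function, `fig, axes = make_grid(len(hues), **subkwargs)` would then replace both the `if` and the `else` branch, followed by a single plotting loop.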
        

You can work around the problem by passing an explicit `label` to each `distplot` call. Every curve then carries its own legend entry, so a plain `ax.legend()`, called without a list of labels, picks up the correct names.
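The mechanism this workaround relies on is plain matplotlib: an artist created with `label=...` is found automatically by `ax.legend()`. The sketch below uses `ax.plot` lines as stand-ins for the KDE curves that `distplot` draws (an assumption made only for illustration). The second panel shows another option: handing `ax.legend` the handles and texts explicitly, which does not depend on matplotlib guessing which artists the labels belong to.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 1, 50)
fig, (ax1, ax2) = plt.subplots(1, 2)

# Option 1: attach the text to each artist when it is drawn...
ax1.plot(x, np.sqrt(x), label="male")
ax1.plot(x, x ** 2, label="female")
legend1 = ax1.legend()  # ...and a bare legend() collects the labelled artists

# Option 2: draw without labels, then pair handles and texts explicitly
line_m, = ax2.plot(x, np.sqrt(x))
line_f, = ax2.plot(x, x ** 2)
legend2 = ax2.legend([line_m, line_f], ["male", "female"])

print([t.get_text() for t in legend1.get_texts()])  # ['male', 'female']
print([t.get_text() for t in legend2.get_texts()])  # ['male', 'female']
```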

Here is some minimal sample code to illustrate how everything works together:

    from matplotlib import pyplot as plt
    import pandas as pd
    import seaborn as sns
    import numpy as np

    def plot_multi_kde_cat(dfs, y, hue, subkwargs=None, distkwargs=None, legends=()):
        subkwargs = subkwargs or {}
        distkwargs = distkwargs or {}
        hues = np.unique(dfs[0][hue])
        # squeeze=False keeps `axes` 2-D even when there is a single category
        fig, axes = plt.subplots(1, len(hues), squeeze=False, **subkwargs)
        axes = axes.flatten()
        for ax, hu in zip(axes, hues):
            for df, legend_label in zip(dfs, legends):
                # label= gives every curve its own legend entry, whether or not the histogram is drawn
                # (note: sns.distplot is deprecated since seaborn 0.11)
                sns.distplot(df.loc[df[hue] == hu, y], ax=ax, label=legend_label, **distkwargs)
            ax.set_title(f"Segment: {hu}")
            ax.legend()  # collects the labels passed to distplot above
        fig.tight_layout()
        return fig, axes

    N = 20
    df_m = pd.DataFrame({'animal': np.random.choice(['tiger', 'horse'], N), 'weight': np.random.uniform(100, 200, N)})
    df_f = pd.DataFrame({'animal': np.random.choice(['tiger', 'horse'], N), 'weight': np.random.uniform(80, 160, N)})
    plot_multi_kde_cat([df_m, df_f], 'weight', 'animal',
                       subkwargs={}, distkwargs={'hist': False}, legends=['male', 'female'])
    plt.show()

[Example plot]
