简体   繁体   中英

Titles for histograms on diagonal when using seaborn.PairGrid in python for generating a correlation matrix

After investing some time in working out how the PairGrid function works, I am nearly there. Below is the code that generates the plot I want, with one small detail missing in histfunc: I would like a title on each histogram plotted on the diagonal. How do I pass the dataframe column names to histfunc? Any ideas appreciated.

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import dcor
import random
from scipy.stats import linregress
from matplotlib import rc

# Global matplotlib font settings, applied through rc() to figures drawn
# after this point.
# 'sans-serif' replaces the original 'normal': 'normal' is a valid weight
# but not a font family name, so matplotlib warns that the family cannot be
# found and falls back to its default font.
font = {
    'family': 'sans-serif',
    'weight': 'normal',
    'size': 4,
}
rc('font', **font)

def corrmat(data):
    """Draw a pairwise correlation matrix of ``data`` with seaborn's PairGrid.

    Layout of the grid:

    * upper triangle -- scatter plots with a degree-5 polynomial fit in red
    * diagonal       -- histograms (30 bins)
    * lower triangle -- r2 of a linear regression plus a significance marker

    Parameters
    ----------
    data : pandas.DataFrame
        Passed unchanged to ``sns.PairGrid``.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``.
    """

    def cm2inch(value):
        """Convert a length in centimetres to inches (for figsize)."""
        return value / 2.54

    def dist_corr(X, Y, pval=True, nruns=100):
        """Return the distance correlation of X and Y.

        With ``pval`` true the result is ``(dc, pv)``, where ``pv`` comes
        from dcor's distance covariance test run with ``nruns`` resamples;
        otherwise only ``dc`` is returned.
        """
        dc = dcor.distance_correlation(X, Y)
        if not pval:
            # Skip the resampling test entirely when nobody asks for it.
            return dc
        # assumes element 0 of the test result is the p-value -- confirm
        # against the installed dcor version
        pv = dcor.independence.distance_covariance_test(
            X, Y, exponent=1.0, num_resamples=nruns)[0]
        return (dc, pv)

    def linreg(X, Y, pval=True):
        """Return r2 of a linear regression of Y on X.

        With ``pval`` true the result is ``(r2, p_value)``; otherwise only
        ``r2``.
        """
        # Fit once and reuse the result for both r2 and the p-value.
        fit = linregress(X, Y)
        r2 = fit.rvalue ** 2
        if pval:
            return (r2, fit.pvalue)
        return r2

    def scatterfunc(x, y, **kws):
        """Scatter plot of y against x with a degree-5 polynomial fit in red."""
        plt.scatter(x, y, linewidths=1, facecolor="k", s=10, alpha=0.5)
        coeffs = np.polyfit(x, y, 5)
        model = np.poly1d(coeffs)
        # Evaluate the fit on sorted x so the line is drawn left to right.
        xs = np.sort(x)
        plt.plot(xs, model(xs), 'r-')

    def histfunc(x, **kws):
        """Histogram of x with 30 bins."""
        plt.hist(x, bins=30, color="black", ec="white")
        # Open question from the original post: something like
        # plt.title(label) is missing here, but **kws only contains `label`
        # as a string, not as a parameter containing the column name.

    def corrfunc(x, y, dc=False, **kws):
        """Annotate the current axes with the correlation of x and y.

        Text colour and size grow with the strength ``d`` of the relation;
        the marker below the value encodes the p-value.  ``dc`` selects
        distance correlation instead of linear regression (r2).
        """
        if dc:
            d, p = dist_corr(x, y)
        else:
            d, p = linreg(x, y)

        # Strength tiers are decided by d alone.  The original compared p in
        # the two upper tiers, which coloured strong correlations by their
        # p-value and could leave pclr/fontsize unassigned.  Checking from
        # the top down with a final else also covers NaN (-> weakest tier).
        if d >= 0.75:
            pclr, fontsize = 'Red', 30
        elif d >= 0.5:
            pclr, fontsize = 'Orange', 25
        elif d >= 0.25:
            pclr, fontsize = 'Blue', 20
        else:
            pclr, fontsize = 'Black', 16

        # Significance marker; the final else also covers p == 0.05 and NaN,
        # which the original chain left without a value.
        if p < 0.001:
            ptext = "***"
        elif p < 0.01:
            ptext = "**"
        elif p < 0.05:
            ptext = "*"
        else:
            ptext = "n.sig"

        prefix = 'DC: ' if dc else 'r2: '
        ax = plt.gca()
        ax.annotate(''.join([prefix, str(np.round(d, 2)), '\n\n    ', ptext]),
                    xy=(0.3, 0.3),
                    xycoords=ax.transAxes,
                    color=pclr,
                    fontsize=fontsize)

        # Text-only cell: hide ticks and frame.
        plt.axis('off')

    # NOTE(review): sns.PairGrid creates its own figure, so this figure
    # presumably stays empty and its size/dpi do not affect the grid --
    # confirm, and consider sizing through PairGrid's own parameters.
    plt.figure(num=None, figsize=(cm2inch(15), cm2inch(10)), dpi=300,
               facecolor='w', edgecolor='k')
    g = sns.PairGrid(data, diag_sharey=False)
    g.map_upper(scatterfunc)
    g.map_diag(histfunc)
    g.map_lower(corrfunc)
    plt.tight_layout()
    plt.show()


########

# Demo input: 1000 rows x 10 columns of uniform random numbers, with the
# columns named "0" .. "9".
data = pd.DataFrame(np.random.random([1000, 10]),
                    columns=[str(i) for i in range(10)])
for i, col in enumerate(data):
    # Columns 0 and 1 are left untouched.  Every later column is, with
    # probability 0.5, multiplied by one randomly chosen earlier column so
    # that some pairs end up correlated.
    if i > 1 and np.random.random() > 0.5:
        # random.randrange(i) picks a position in 0..i-1.  The original
        # random.sample(set(...), 1)[0] raises TypeError on Python 3.11+,
        # where sampling from a set is no longer supported.
        data[col] = data[col] * data.iloc[:, random.randrange(i)]
corrmat(data)

What it generates is:

在此处输入图片说明

Thanks to @ImportanceOfBeingErnest's comment, here is an updated script for those who might find it useful. I also switched the scatterplot to "lower" so that the axis labels become visible.

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import dcor
import random
from scipy.stats import linregress
from matplotlib import rc

# Global matplotlib font settings, applied through rc() to figures drawn
# after this point.
# 'sans-serif' replaces the original 'normal': 'normal' is a valid weight
# but not a font family name, so matplotlib warns that the family cannot be
# found and falls back to its default font.
font = {
    'family': 'sans-serif',
    'weight': 'normal',
    'size': 16,
}
rc('font', **font)

def corrmat(data):
    """Draw a pairwise correlation matrix of ``data`` with seaborn's PairGrid.

    Layout of the grid:

    * lower triangle -- scatter plots with a degree-5 polynomial fit in red
    * diagonal       -- histograms (30 bins), titled with the column names
    * upper triangle -- r2 of a linear regression plus a significance marker

    Parameters
    ----------
    data : pandas.DataFrame
        Passed unchanged to ``sns.PairGrid``; ``data.columns`` supplies the
        diagonal titles.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``.
    """

    def cm2inch(value):
        """Convert a length in centimetres to inches (for figsize)."""
        return value / 2.54

    def dist_corr(X, Y, pval=True, nruns=100):
        """Return the distance correlation of X and Y.

        With ``pval`` true the result is ``(dc, pv)``, where ``pv`` comes
        from dcor's distance covariance test run with ``nruns`` resamples;
        otherwise only ``dc`` is returned.
        """
        dc = dcor.distance_correlation(X, Y)
        if not pval:
            # Skip the resampling test entirely when nobody asks for it.
            return dc
        # assumes element 0 of the test result is the p-value -- confirm
        # against the installed dcor version
        pv = dcor.independence.distance_covariance_test(
            X, Y, exponent=1.0, num_resamples=nruns)[0]
        return (dc, pv)

    def linreg(X, Y, pval=True):
        """Return r2 of a linear regression of Y on X.

        With ``pval`` true the result is ``(r2, p_value)``; otherwise only
        ``r2``.
        """
        # Fit once and reuse the result for both r2 and the p-value.
        fit = linregress(X, Y)
        r2 = fit.rvalue ** 2
        if pval:
            return (r2, fit.pvalue)
        return r2

    def scatterfunc(x, y, **kws):
        """Scatter plot of y against x with a degree-5 polynomial fit in red."""
        plt.scatter(x, y, linewidths=1, facecolor="k", s=10, alpha=0.5)
        coeffs = np.polyfit(x, y, 5)
        model = np.poly1d(coeffs)
        # Evaluate the fit on sorted x so the line is drawn left to right.
        xs = np.sort(x)
        plt.plot(xs, model(xs), 'r-')

    def histfunc(x, **kws):
        """Histogram of x with 30 bins."""
        plt.hist(x, bins=30, color="black", ec="white")

    def corrfunc(x, y, dc=False, **kws):
        """Annotate the current axes with the correlation of x and y.

        Text colour and size grow with the strength ``d`` of the relation;
        the marker below the value encodes the p-value.  ``dc`` selects
        distance correlation instead of linear regression (r2).
        """
        if dc:
            d, p = dist_corr(x, y)
        else:
            d, p = linreg(x, y)

        # Strength tiers are decided by d alone.  The original compared p in
        # the two upper tiers, which coloured strong correlations by their
        # p-value and could leave pclr/fontsize unassigned.  Checking from
        # the top down with a final else also covers NaN (-> weakest tier).
        if d >= 0.75:
            pclr, fontsize = 'Red', 30
        elif d >= 0.5:
            pclr, fontsize = 'Orange', 25
        elif d >= 0.25:
            pclr, fontsize = 'Blue', 20
        else:
            pclr, fontsize = 'Black', 16

        # Significance marker; the final else also covers p == 0.05 and NaN,
        # which the original chain left without a value.
        if p < 0.001:
            ptext = "***"
        elif p < 0.01:
            ptext = "**"
        elif p < 0.05:
            ptext = "*"
        else:
            ptext = "n.sig"

        prefix = 'DC: ' if dc else 'r2: '
        ax = plt.gca()
        ax.annotate(''.join([prefix, str(np.round(d, 2)), '\n\n    ', ptext]),
                    xy=(0.3, 0.3),
                    xycoords=ax.transAxes,
                    color=pclr,
                    fontsize=fontsize)

        # Text-only cell: hide ticks and frame.
        plt.axis('off')

    def make_diag_titles(g, titles):
        """Set ``titles[i]`` as the title of the i-th diagonal axes of ``g``.

        Returns ``g`` so the call can be chained or reassigned.
        """
        for i, row in enumerate(g.axes):
            row[i].set_title(titles[i])
        return g

    # Assemble the plot.
    # NOTE(review): sns.PairGrid creates its own figure, so this figure
    # presumably stays empty and its size/dpi do not affect the grid --
    # confirm, and consider sizing through PairGrid's own parameters.
    plt.figure(num=None, figsize=(cm2inch(15), cm2inch(10)), dpi=300,
               facecolor='w', edgecolor='k')
    g = sns.PairGrid(data, diag_sharey=False)
    g.map_lower(scatterfunc)
    g.map_diag(histfunc)
    g.map_upper(corrfunc)
    g = make_diag_titles(g, data.columns)
    plt.tight_layout()
    plt.show()


########

# Demo input: 1000 rows x 10 columns of uniform random numbers, with the
# columns named "0" .. "9".
data = pd.DataFrame(np.random.random([1000, 10]),
                    columns=[str(i) for i in range(10)])
for i, col in enumerate(data):
    # Columns 0 and 1 are left untouched.  Every later column is, with
    # probability 0.5, multiplied by one randomly chosen earlier column so
    # that some pairs end up correlated.
    if i > 1 and np.random.random() > 0.5:
        # random.randrange(i) picks a position in 0..i-1.  The original
        # random.sample(set(...), 1)[0] raises TypeError on Python 3.11+,
        # where sampling from a set is no longer supported.
        data[col] = data[col] * data.iloc[:, random.randrange(i)]
corrmat(data)

在此处输入图片说明

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM