After investing some time in figuring out how the PairGrid function works, I am nearly there. Below is the code that generates the plot I want, with one small detail missing in histfunc: I want a title on each of the histograms plotted on the diagonal. How do I pass the dataframe column names to histfunc? Any ideas are appreciated.
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import dcor
import random
from scipy.stats import linregress
from matplotlib import rc
# Global matplotlib font settings applied to every figure created below.
# 'sans-serif' replaces the original 'normal': 'normal' is a font *weight*,
# not a font family, so matplotlib could not resolve it and fell back to its
# default font after emitting a findfont warning.
font = {'family': 'sans-serif',
        'weight': 'normal',
        'size': 4}
rc('font', **font)
def corrmat(data):
    """Draw a pairwise correlation-matrix plot for the columns of ``data``.

    The plot is a seaborn ``PairGrid`` laid out as follows:

    * upper triangle: scatter plot with a degree-5 polynomial fit in red,
    * diagonal: histogram of each column,
    * lower triangle: text annotation with r2 (linear regression) and a
      significance marker.

    Args:
        data: Passed straight to ``sns.PairGrid``; the caller in this file
            passes a pandas DataFrame.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """

    def cm2inch(value):
        """Helper function for plotting. Converts cm to inch."""
        return value / 2.54

    def dist_corr(X, Y, pval=True, nruns=100):
        """Distance correlation, optionally with a resampling p-value.

        Returns ``(dc, pv)`` when ``pval`` is true, otherwise just ``dc``.
        ``nruns`` is forwarded as ``num_resamples`` to the dcor test.
        """
        dc = dcor.distance_correlation(X, Y)
        if not pval:
            # Skip the resampling test entirely when no p-value is requested.
            return dc
        pv = dcor.independence.distance_covariance_test(
            X, Y, exponent=1.0, num_resamples=nruns)[0]
        return (dc, pv)

    def linreg(X, Y, pval=True):
        """Linear regression.

        Returns ``(r2, pv)`` when ``pval`` is true, otherwise just ``r2``.
        """
        # Fit once and reuse the result instead of running linregress twice.
        fit = linregress(X, Y)
        r2 = fit[2] ** 2
        if pval:
            return (r2, fit[3])
        return r2

    def scatterfunc(x, y, **kws):
        """Scatter plot of ``y`` against ``x`` with a degree-5 polynomial fit in red."""
        plt.scatter(x, y, linewidths=1, facecolor="k", s=10, alpha=0.5)
        coeffs = np.polyfit(x, y, 5)
        model = np.poly1d(coeffs)
        # Evaluate the fit on sorted x so the curve is drawn left to right.
        x = np.sort(x)
        plt.plot(x, model(x), 'r-')

    def histfunc(x, **kws):
        """Histogram of ``x`` with 30 bins on the current axes."""
        plt.hist(x, bins=30, color="black", ec="white")
        # Open question from the original author: something like
        #     plt.title(label)
        # is missing here, but **kws only contains ``label`` as a string,
        # not as a parameter containing the column name.

    def corrfunc(x, y, dc=False, **kws):
        """Annotate the current axes with the correlation strength of x and y.

        Text size and colour grow with the r2 / distance-correlation value;
        the significance marker is derived from the p-value. ``dc`` selects
        distance correlation instead of linear regression.
        """
        if dc:
            d, p = dist_corr(x, y)
        else:
            d, p = linreg(x, y)

        # Colour and font size depend on the strength ``d`` only. The original
        # code compared ``p`` against 0.75 in the last two branches, which
        # mislabelled strong correlations and could leave ``pclr`` unset.
        # Checking from the top down also sends a NaN ``d`` to the weakest
        # bucket instead of raising.
        if d >= 0.75:
            pclr = 'Red'
            fontsize = 30
        elif d >= 0.5:
            pclr = 'Orange'
            fontsize = 25
        elif d >= 0.25:
            pclr = 'Blue'
            fontsize = 20
        else:
            pclr = 'Black'
            fontsize = 16

        # Significance marker; the final ``else`` also covers p == 0.05,
        # which previously matched no branch.
        if p < 0.001:
            ptext = "***"
        elif p < 0.01:
            ptext = "**"
        elif p < 0.05:
            ptext = "*"
        else:
            ptext = "n.sig"

        prefix = 'DC: ' if dc else 'r2: '
        ax = plt.gca()
        ax.annotate(''.join([prefix, str(np.round(d, 2)), '\n\n ', ptext]),
                    xy=(0.3, 0.3),
                    xycoords=ax.transAxes,
                    color=pclr,
                    fontsize=fontsize)
        plt.axis('off')

    # NOTE(review): sns.PairGrid appears to create its own figure, so this
    # figure may stay empty -- confirm whether this call is still wanted.
    plt.figure(num=None, figsize=(cm2inch(15), cm2inch(10)), dpi=300,
               facecolor='w', edgecolor='k')
    g = sns.PairGrid(data, diag_sharey=False)
    g.map_upper(scatterfunc)
    g.map_diag(histfunc)
    g.map_lower(corrfunc)
    plt.tight_layout()
    plt.show()
########
# Demo data: 1000 rows of uniform random numbers in 10 columns named "0".."9".
data = pd.DataFrame(np.random.random([1000, 10]),
                    columns=[str(i) for i in range(10)])
for i, col in enumerate(data):
    # From the third column on, with probability of about one half, multiply
    # the column by a randomly chosen earlier column so that some pairs of
    # columns are correlated.
    if i > 1 and np.random.random() > 0.5:
        # random.randrange(i) draws uniformly from 0..i-1. The original
        # random.sample(set(...), 1) is rejected by Python 3.11+, where
        # sampling from a set raises TypeError.
        data[col] = data[col] * data.iloc[:, random.randrange(i)]
corrmat(data)
what it generates is
Thanks to @ImportanceOfBeingErnest's comment, here is an updated script for those who might find it useful. I also switched the scatterplot to "lower" so that the axes labels become visible.
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import dcor
import random
from scipy.stats import linregress
from matplotlib import rc
# Global matplotlib font settings applied to every figure created below.
# 'sans-serif' replaces the original 'normal': 'normal' is a font *weight*,
# not a font family, so matplotlib could not resolve it and fell back to its
# default font after emitting a findfont warning.
font = {'family': 'sans-serif',
        'weight': 'normal',
        'size': 16}
rc('font', **font)
def corrmat(data):
    """Draw a pairwise correlation-matrix plot for the columns of ``data``.

    The plot is a seaborn ``PairGrid`` laid out as follows:

    * lower triangle: scatter plot with a degree-5 polynomial fit in red,
    * diagonal: histogram of each column, titled with the column name,
    * upper triangle: text annotation with r2 (linear regression) and a
      significance marker.

    Args:
        data: Passed straight to ``sns.PairGrid``; its ``columns`` attribute
            supplies the diagonal titles. The caller in this file passes a
            pandas DataFrame.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """

    def cm2inch(value):
        """Helper function for plotting. Converts cm to inch."""
        return value / 2.54

    def dist_corr(X, Y, pval=True, nruns=100):
        """Distance correlation, optionally with a resampling p-value.

        Returns ``(dc, pv)`` when ``pval`` is true, otherwise just ``dc``.
        ``nruns`` is forwarded as ``num_resamples`` to the dcor test.
        """
        dc = dcor.distance_correlation(X, Y)
        if not pval:
            # Skip the resampling test entirely when no p-value is requested.
            return dc
        pv = dcor.independence.distance_covariance_test(
            X, Y, exponent=1.0, num_resamples=nruns)[0]
        return (dc, pv)

    def linreg(X, Y, pval=True):
        """Linear regression.

        Returns ``(r2, pv)`` when ``pval`` is true, otherwise just ``r2``.
        """
        # Fit once and reuse the result instead of running linregress twice.
        fit = linregress(X, Y)
        r2 = fit[2] ** 2
        if pval:
            return (r2, fit[3])
        return r2

    def scatterfunc(x, y, **kws):
        """Scatter plot of ``y`` against ``x`` with a degree-5 polynomial fit in red."""
        plt.scatter(x, y, linewidths=1, facecolor="k", s=10, alpha=0.5)
        coeffs = np.polyfit(x, y, 5)
        model = np.poly1d(coeffs)
        # Evaluate the fit on sorted x so the curve is drawn left to right.
        x = np.sort(x)
        plt.plot(x, model(x), 'r-')

    def histfunc(x, **kws):
        """Histogram of ``x`` with 30 bins on the current axes."""
        plt.hist(x, bins=30, color="black", ec="white")

    def corrfunc(x, y, dc=False, **kws):
        """Annotate the current axes with the correlation strength of x and y.

        Text size and colour grow with the r2 / distance-correlation value;
        the significance marker is derived from the p-value. The ``dc``
        parameter determines whether distance correlation or linear
        regression is applied.
        """
        if dc:
            d, p = dist_corr(x, y)
        else:
            d, p = linreg(x, y)

        # Colour and font size depend on the strength ``d`` only. The original
        # code compared ``p`` against 0.75 in the last two branches, which
        # mislabelled strong correlations and could leave ``pclr`` unset.
        # Checking from the top down also sends a NaN ``d`` to the weakest
        # bucket instead of raising.
        if d >= 0.75:
            pclr = 'Red'
            fontsize = 30
        elif d >= 0.5:
            pclr = 'Orange'
            fontsize = 25
        elif d >= 0.25:
            pclr = 'Blue'
            fontsize = 20
        else:
            pclr = 'Black'
            fontsize = 16

        # Significance marker; the final ``else`` also covers p == 0.05,
        # which previously matched no branch.
        if p < 0.001:
            ptext = "***"
        elif p < 0.01:
            ptext = "**"
        elif p < 0.05:
            ptext = "*"
        else:
            ptext = "n.sig"

        prefix = 'DC: ' if dc else 'r2: '
        ax = plt.gca()
        ax.annotate(''.join([prefix, str(np.round(d, 2)), '\n\n ', ptext]),
                    xy=(0.3, 0.3),
                    xycoords=ax.transAxes,
                    color=pclr,
                    fontsize=fontsize)
        plt.axis('off')

    def make_diag_titles(g, titles):
        """Set ``titles[i]`` as the title of the i-th diagonal axes; return ``g``."""
        for i, _row in enumerate(g.axes):
            g.axes[i, i].title.set_text(titles[i])
        return g

    ###
    # here the plot is put together
    # NOTE(review): sns.PairGrid appears to create its own figure, so this
    # figure may stay empty -- confirm whether this call is still wanted.
    plt.figure(num=None, figsize=(cm2inch(15), cm2inch(10)), dpi=300,
               facecolor='w', edgecolor='k')
    g = sns.PairGrid(data, diag_sharey=False)
    g.map_lower(scatterfunc)
    g.map_diag(histfunc)
    g.map_upper(corrfunc)
    g = make_diag_titles(g, data.columns)
    plt.tight_layout()
    plt.show()
########
# Demo data: 1000 rows of uniform random numbers in 10 columns named "0".."9".
data = pd.DataFrame(np.random.random([1000, 10]),
                    columns=[str(i) for i in range(10)])
for i, col in enumerate(data):
    # From the third column on, with probability of about one half, multiply
    # the column by a randomly chosen earlier column so that some pairs of
    # columns are correlated.
    if i > 1 and np.random.random() > 0.5:
        # random.randrange(i) draws uniformly from 0..i-1. The original
        # random.sample(set(...), 1) is rejected by Python 3.11+, where
        # sampling from a set raises TypeError.
        data[col] = data[col] * data.iloc[:, random.randrange(i)]
corrmat(data)
The technical posts on this site follow the CC BY-SA 4.0 license. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.