[英]Annotating seaborn regplot parameters to the plot
我正在尝试使用seaborn.regplot
制作具有r2
、 p
和rmse
值的散点图。 但以下代码返回AttributeError: 'AxesSubplot' object has no attribute 'map_dataframe'
fig, axes = plt.subplots(1, 2, figsize=(15, 5), sharey=True)
g = sns.regplot(x='est_fmc', y='1h_surface', data=new_df, ax=axes[0])
def annotate(data, **kws):
slope, intercept, rvalue, pvalue, stderr = scipy.stats.linregress(x = data['est_fmc'], y= data['1h_surface'] )
rmse = mean_squared_error(data['est_fmc'], data['1h_surface'], squared=False)
print(slope, intercept, rvalue, pvalue, rmse)
ax = plt.gca()
ax.text(.02, .9, 'r2={:.2f}, p={:.2g}, rmse = {:.2f}'.format(rvalue**2, pvalue, rmse),
transform=ax.transAxes)
g.map_dataframe(annotate)
g = sns.regplot(x='est_fmc', y='1h_profile', data=new_df, ax = axes[1] )
def annotate(data, **kws):
slope, intercept, rvalue, pvalue, stderr = scipy.stats.linregress(x = data['est_fmc'], y= data['1h_profile'] )
rmse = mean_squared_error(data['est_fmc'], data['1h_profile'], squared=False)
print(slope, intercept, rvalue, pvalue, rmse)
ax = plt.gca()
ax.text(.02, .9, 'r2={:.2f}, p={:.2g}, rmse = {:.2f}'.format(rvalue**2, pvalue, rmse),
transform=ax.transAxes)
g.map_dataframe(annotate)
有办法解决吗? 我真的很感激任何帮助。
seaborn 的一个重要方面是图形级和轴级功能之间的区别。 sns.regplot
是轴级别的 function。 它获取一个ax
(指示子图)作为可选参数,并始终返回创建 plot 的ax
。
map_dataframe
旨在与图形级函数(创建子图网格)一起使用。 它可以与 relplot 等relplot
一起使用。 请注意,图形级函数不接受ax
作为参数,它们总是创建自己的新图形。
在您的情况下,您可以使用ax
参数以及x
和y
的参数修改annotate
function 以使其适用于两个子图。 (Python 中的一个重要概念是“干——不要重复自己” 。)
这是修改后的代码,从一些测试数据开始。 (进一步的改进是将调用regplot
放入annotate
function,将 function 重命名为“regplot_with_annotation”)。
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import scipy
from sklearn.metrics import mean_squared_error
def annotate(ax, data, x, y):
slope, intercept, rvalue, pvalue, stderr = scipy.stats.linregress(x=data[x], y=data[y])
rmse = mean_squared_error(data[x], data[y], squared=False)
ax.text(.02, .9, f'r2={rvalue ** 2:.2f}, p={pvalue:.2g}, rmse={rmse:.2f}', transform=ax.transAxes)
est_fmc = np.random.uniform(0, 10, 100)
oneh_surface = 2 * est_fmc + np.random.normal(0, 5, 100) + 10
oneh_profile = 3 * est_fmc + np.random.normal(0, 3, 100) + 5
new_df = pd.DataFrame({'est_fmc': est_fmc, '1h_surface': oneh_surface, '1h_profile': oneh_profile})
fig, axes = plt.subplots(1, 2, figsize=(15, 5), sharey=True)
ax = sns.regplot(x='est_fmc', y='1h_surface', data=new_df, ax=axes[0])
annotate(ax, data=new_df, x='est_fmc', y='1h_surface')
ax = sns.regplot(x='est_fmc', y='1h_profile', data=new_df, ax=axes[1])
annotate(ax, data=new_df, x='est_fmc', y='1h_profile')
plt.tight_layout()
plt.show()
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.