I am trying to make scatter plots with r2, p and rmse values using seaborn.regplot. But the following code returns an error: AttributeError: 'AxesSubplot' object has no attribute 'map_dataframe'
fig, axes = plt.subplots(1, 2, figsize=(15, 5), sharey=True)
g = sns.regplot(x='est_fmc', y='1h_surface', data=new_df, ax=axes[0])
def annotate(data, **kws):
    slope, intercept, rvalue, pvalue, stderr = scipy.stats.linregress(x = data['est_fmc'], y= data['1h_surface'] )
    rmse = mean_squared_error(data['est_fmc'], data['1h_surface'], squared=False)
    print(slope, intercept, rvalue, pvalue, rmse)
    ax = plt.gca()
    ax.text(.02, .9, 'r2={:.2f}, p={:.2g}, rmse = {:.2f}'.format(rvalue**2, pvalue, rmse),
            transform=ax.transAxes)
g.map_dataframe(annotate)
g = sns.regplot(x='est_fmc', y='1h_profile', data=new_df, ax = axes[1] )
def annotate(data, **kws):
    slope, intercept, rvalue, pvalue, stderr = scipy.stats.linregress(x = data['est_fmc'], y= data['1h_profile'] )
    rmse = mean_squared_error(data['est_fmc'], data['1h_profile'], squared=False)
    print(slope, intercept, rvalue, pvalue, rmse)
    ax = plt.gca()
    ax.text(.02, .9, 'r2={:.2f}, p={:.2g}, rmse = {:.2f}'.format(rvalue**2, pvalue, rmse),
            transform=ax.transAxes)
g.map_dataframe(annotate)
Is there a way to work around this? I would really appreciate any help.
An important aspect of seaborn is the difference between figure-level and axes-level functions. sns.regplot is an axes-level function. It accepts an ax (indicating the subplot) as an optional parameter, and always returns the ax on which the plot has been created. That returned object is a plain matplotlib Axes, which is why calling g.map_dataframe on it raises the AttributeError.
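A quick way to see where the error comes from is to inspect what regplot returns. This is a minimal sketch with a small synthetic DataFrame (the data are made up; only the column names are borrowed from the question):

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes

# Synthetic data, for illustration only
rng = np.random.default_rng(0)
df = pd.DataFrame({'est_fmc': rng.uniform(0, 10, 50)})
df['1h_surface'] = 2 * df['est_fmc'] + rng.normal(0, 5, 50) + 10

fig, axes = plt.subplots(1, 2)
ax = sns.regplot(x='est_fmc', y='1h_surface', data=df, ax=axes[0])

# regplot draws onto the Axes it was given and returns that same object
print(ax is axes[0])              # True
print(isinstance(ax, Axes))       # True
# A matplotlib Axes has no map_dataframe method, hence the AttributeError
print(hasattr(ax, 'map_dataframe'))  # False
```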
map_dataframe is a method of seaborn's FacetGrid, the grid-of-subplots object that figure-level functions create and return. It can work together with a function such as relplot, or lmplot, which is the figure-level counterpart of regplot. Note that figure-level functions don't accept an ax parameter; they always create their own new figure.
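If you do want to use map_dataframe, one option is to switch to the figure-level route: reshape the data to long form and let sns.lmplot build a FacetGrid with one regression panel per column. This is a sketch under some assumptions: the data are synthetic, the target column is a hypothetical name created by DataFrame.melt, annotate_facet is a hypothetical variant of the question's annotate that receives x and y as keywords, and RMSE is computed with NumPy instead of scikit-learn.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats

def annotate_facet(data, x, y, **kws):
    # map_dataframe calls this once per facet with that facet's rows as `data`;
    # extra keywords it adds (such as `color`) end up in **kws and are ignored.
    res = stats.linregress(data[x], data[y])
    # RMSE of the direct x-vs-y differences, as in the question
    rmse = np.sqrt(np.mean((data[x] - data[y]) ** 2))
    ax = plt.gca()  # map_dataframe makes the current facet the active Axes
    ax.text(.02, .9, f'r2={res.rvalue ** 2:.2f}, p={res.pvalue:.2g}, rmse={rmse:.2f}',
            transform=ax.transAxes)

# Synthetic test data
rng = np.random.default_rng(0)
est_fmc = rng.uniform(0, 10, 100)
new_df = pd.DataFrame({
    'est_fmc': est_fmc,
    '1h_surface': 2 * est_fmc + rng.normal(0, 5, 100) + 10,
    '1h_profile': 3 * est_fmc + rng.normal(0, 3, 100) + 5,
})

# Long form: one row per (est_fmc, target column, value)
long_df = new_df.melt(id_vars='est_fmc', var_name='target', value_name='value')

# lmplot is figure-level: it returns a FacetGrid, which does have map_dataframe
g = sns.lmplot(data=long_df, x='est_fmc', y='value', col='target')
g.map_dataframe(annotate_facet, x='est_fmc', y='value')
plt.show()
```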
In your case, you can give the annotate function an ax parameter, plus parameters for x and y, so that the same function works for both subplots. (An important principle in programming is "DRY: Don't Repeat Yourself".)
Here is the modified code, starting from some test data. (A further improvement would be to also put the call to regplot into the annotate function, renaming that function to something like "regplot_with_annotation".)
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import scipy.stats  # import the submodule explicitly; a bare "import scipy" may not expose scipy.stats
from sklearn.metrics import mean_squared_error

def annotate(ax, data, x, y):
    # fit y against x and write r2, p and RMSE into the top-left corner of ax
    slope, intercept, rvalue, pvalue, stderr = scipy.stats.linregress(x=data[x], y=data[y])
    # np.sqrt(...) instead of squared=False, which is deprecated in newer scikit-learn releases
    rmse = np.sqrt(mean_squared_error(data[x], data[y]))
    ax.text(.02, .9, f'r2={rvalue ** 2:.2f}, p={pvalue:.2g}, rmse={rmse:.2f}', transform=ax.transAxes)

# test data
est_fmc = np.random.uniform(0, 10, 100)
oneh_surface = 2 * est_fmc + np.random.normal(0, 5, 100) + 10
oneh_profile = 3 * est_fmc + np.random.normal(0, 3, 100) + 5
new_df = pd.DataFrame({'est_fmc': est_fmc, '1h_surface': oneh_surface, '1h_profile': oneh_profile})

fig, axes = plt.subplots(1, 2, figsize=(15, 5), sharey=True)
ax = sns.regplot(x='est_fmc', y='1h_surface', data=new_df, ax=axes[0])
annotate(ax, data=new_df, x='est_fmc', y='1h_surface')
ax = sns.regplot(x='est_fmc', y='1h_profile', data=new_df, ax=axes[1])
annotate(ax, data=new_df, x='est_fmc', y='1h_profile')
plt.tight_layout()
plt.show()
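The further improvement of folding the regplot call into the helper could look roughly like this. regplot_with_annotation is a hypothetical helper name, the data are synthetic, and RMSE is computed with NumPy (equivalent to the square root of mean_squared_error) so the sketch does not need scikit-learn:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats

def regplot_with_annotation(data, x, y, ax=None, **regplot_kws):
    """Hypothetical helper: draw sns.regplot on ax and label it with r2, p and RMSE."""
    ax = sns.regplot(x=x, y=y, data=data, ax=ax, **regplot_kws)
    res = stats.linregress(data[x], data[y])
    # same value as np.sqrt(mean_squared_error(data[x], data[y]))
    rmse = np.sqrt(np.mean((data[x] - data[y]) ** 2))
    ax.text(.02, .9, f'r2={res.rvalue ** 2:.2f}, p={res.pvalue:.2g}, rmse={rmse:.2f}',
            transform=ax.transAxes)
    return ax

# Synthetic test data
rng = np.random.default_rng(0)
est_fmc = rng.uniform(0, 10, 100)
new_df = pd.DataFrame({
    'est_fmc': est_fmc,
    '1h_surface': 2 * est_fmc + rng.normal(0, 5, 100) + 10,
    '1h_profile': 3 * est_fmc + rng.normal(0, 3, 100) + 5,
})

fig, axes = plt.subplots(1, 2, figsize=(15, 5), sharey=True)
# one loop instead of two copy-pasted blocks
for ax, col in zip(axes, ['1h_surface', '1h_profile']):
    regplot_with_annotation(new_df, x='est_fmc', y=col, ax=ax)
plt.tight_layout()
plt.show()
```

Passing extra keywords through **regplot_kws keeps regplot's own options (for example scatter_kws or ci) available without changing the helper.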