
Annotating seaborn regplot parameters to the plot

I am trying to make scatter plots annotated with r2, p and rmse values using seaborn.regplot. But the following code raises AttributeError: 'AxesSubplot' object has no attribute 'map_dataframe'

fig, axes = plt.subplots(1, 2, figsize=(15, 5), sharey=True)


g = sns.regplot(x='est_fmc', y='1h_surface', data=new_df, ax=axes[0])
def annotate(data, **kws):
    slope, intercept, rvalue, pvalue, stderr = scipy.stats.linregress(x = data['est_fmc'], y= data['1h_surface'] )
    rmse = mean_squared_error(data['est_fmc'], data['1h_surface'], squared=False)
    print(slope, intercept, rvalue, pvalue, rmse)
    ax = plt.gca()
    ax.text(.02, .9, 'r2={:.2f}, p={:.2g}, rmse = {:.2f}'.format(rvalue**2, pvalue, rmse),
            transform=ax.transAxes)
g.map_dataframe(annotate)


g = sns.regplot(x='est_fmc', y='1h_profile', data=new_df, ax = axes[1] )
def annotate(data, **kws):
    slope, intercept, rvalue, pvalue, stderr = scipy.stats.linregress(x = data['est_fmc'], y= data['1h_profile'] )
    rmse = mean_squared_error(data['est_fmc'], data['1h_profile'], squared=False)
    print(slope, intercept, rvalue, pvalue, rmse)
    ax = plt.gca()
    ax.text(.02, .9, 'r2={:.2f}, p={:.2g}, rmse = {:.2f}'.format(rvalue**2, pvalue, rmse),
            transform=ax.transAxes)
g.map_dataframe(annotate)

Is there a way to work around this? I would really appreciate any help.

An important aspect of seaborn is the difference between figure-level and axes-level functions. sns.regplot is an axes-level function. It accepts an ax (indicating the subplot) as an optional parameter, and always returns the ax on which the plot has been created. That is why g in your code is a matplotlib axes object, which has no map_dataframe method.
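A small check makes this visible. This is only a sketch with made-up data (the `df` columns `x` and `y` are invented for the illustration), showing that regplot hands back the same axes it was given:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so the sketch also runs without a display
import matplotlib.axes
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Made-up data, only for this illustration
df = pd.DataFrame({'x': np.arange(10.0), 'y': np.arange(10.0) ** 2})

fig, axes = plt.subplots(1, 2)
returned = sns.regplot(data=df, x='x', y='y', ax=axes[0])

# regplot returns the very axes object that was passed in via ax=
print(returned is axes[0])
print(isinstance(returned, matplotlib.axes.Axes))
# A matplotlib axes has no map_dataframe method, hence the AttributeError
print(hasattr(returned, 'map_dataframe'))
```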

map_dataframe is a method of FacetGrid, the object returned by figure-level functions (which create a grid of subplots) such as relplot or lmplot. Note that figure-level functions don't accept an ax as parameter; they always create their own new figure.
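For comparison, here is a rough sketch of the figure-level route, which is an alternative to the approach recommended in this answer. It assumes the data is first reshaped to long form; the column names `target` and `fmc` are invented for the sketch, and the test data is made up:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so the sketch also runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats

# Made-up test data using the column names from the question
rng = np.random.default_rng(0)
est_fmc = rng.uniform(0, 10, 100)
new_df = pd.DataFrame({
    'est_fmc': est_fmc,
    '1h_surface': 2 * est_fmc + rng.normal(0, 5, 100) + 10,
    '1h_profile': 3 * est_fmc + rng.normal(0, 3, 100) + 5,
})

# Long form: one row per (est_fmc, target, fmc) triple.
# 'target' and 'fmc' are column names invented for this sketch.
long_df = new_df.melt(id_vars='est_fmc', var_name='target', value_name='fmc')

def annotate(data, **kws):
    # map_dataframe calls this once per facet, with that facet's rows as `data`
    res = stats.linregress(data['est_fmc'], data['fmc'])
    rmse = np.sqrt(np.mean((data['est_fmc'] - data['fmc']) ** 2))
    ax = plt.gca()  # the facet being processed is the current axes
    ax.text(.02, .9, f'r2={res.rvalue ** 2:.2f}, p={res.pvalue:.2g}, rmse={rmse:.2f}',
            transform=ax.transAxes)

# lmplot creates its own figure and returns a FacetGrid
g = sns.lmplot(data=long_df, x='est_fmc', y='fmc', col='target')
g.map_dataframe(annotate)
```

Here map_dataframe works because g is a FacetGrid rather than a single axes. The price is the reshaping step and giving up control over the figure layout, which is why passing ax to regplot is simpler when plt.subplots is already in use.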

In your case, you can give the annotate function an ax parameter, plus parameters for x and y, so that one function works for both subplots. (An important principle in programming is "DRY - Don't Repeat Yourself".)

Here is the modified code, starting from some test data. (A further improvement would be to also put the call to regplot into the annotate function, renaming that function to something like "regplot_with_annotation").

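import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from scipy import stats
from sklearn.metrics import mean_squared_error

def annotate(ax, data, x, y):
    # Linear regression of y on x gives the r and p values for the label
    slope, intercept, rvalue, pvalue, stderr = stats.linregress(x=data[x], y=data[y])
    # RMSE between the two columns; np.sqrt avoids the squared=False argument,
    # which newer scikit-learn releases deprecate
    rmse = np.sqrt(mean_squared_error(data[x], data[y]))
    # transform=ax.transAxes places the text in axes coordinates (0 to 1)
    ax.text(.02, .9, f'r2={rvalue ** 2:.2f}, p={pvalue:.2g}, rmse={rmse:.2f}', transform=ax.transAxes)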

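# Synthetic test data standing in for the real measurements
np.random.seed(0)  # fixed seed so the example is reproducible
est_fmc = np.random.uniform(0, 10, 100)
oneh_surface = 2 * est_fmc + np.random.normal(0, 5, 100) + 10
oneh_profile = 3 * est_fmc + np.random.normal(0, 3, 100) + 5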

new_df = pd.DataFrame({'est_fmc': est_fmc, '1h_surface': oneh_surface, '1h_profile': oneh_profile})

fig, axes = plt.subplots(1, 2, figsize=(15, 5), sharey=True)

ax = sns.regplot(x='est_fmc', y='1h_surface', data=new_df, ax=axes[0])
annotate(ax, data=new_df, x='est_fmc', y='1h_surface')

ax = sns.regplot(x='est_fmc', y='1h_profile', data=new_df, ax=axes[1])
annotate(ax, data=new_df, x='est_fmc', y='1h_profile')

plt.tight_layout()
plt.show()

[Figure: sns.regplot with annotation]
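The further improvement mentioned earlier, folding the regplot call into the helper, might look roughly like this. The name `regplot_with_annotation` comes from the suggestion above; its exact signature (including the `**regplot_kws` pass-through) is an assumption of this sketch, and the test data is made up:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so the sketch also runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats

def regplot_with_annotation(data, x, y, ax=None, **regplot_kws):
    """Draw sns.regplot and write r2, p and rmse into the top-left corner.

    Hypothetical helper: extra keyword arguments are passed on to sns.regplot.
    """
    ax = sns.regplot(data=data, x=x, y=y, ax=ax, **regplot_kws)
    res = stats.linregress(data[x], data[y])
    # RMSE between the two columns, as in the question
    rmse = np.sqrt(np.mean((data[x] - data[y]) ** 2))
    ax.text(.02, .9, f'r2={res.rvalue ** 2:.2f}, p={res.pvalue:.2g}, rmse={rmse:.2f}',
            transform=ax.transAxes)
    return ax

# Made-up test data using the column names from the question
rng = np.random.default_rng(0)
est_fmc = rng.uniform(0, 10, 100)
new_df = pd.DataFrame({
    'est_fmc': est_fmc,
    '1h_surface': 2 * est_fmc + rng.normal(0, 5, 100) + 10,
    '1h_profile': 3 * est_fmc + rng.normal(0, 3, 100) + 5,
})

fig, axes = plt.subplots(1, 2, figsize=(15, 5), sharey=True)
# One call per subplot; only the y column and the target axes change
for ax, y in zip(axes, ['1h_surface', '1h_profile']):
    regplot_with_annotation(new_df, x='est_fmc', y=y, ax=ax)
fig.tight_layout()
```

In an interactive session, drop the matplotlib.use("Agg") line and finish with plt.show() to display the figure.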
