
Is there a way to get the x-axis and y-axis values of my seaborn plot?

I used the seaborn library to fit a regression line to my data, and I also plotted the residual plot. I now need to see the histogram of my residuals. How can I do that, given that I don't have the values that are plotted in the graph?

Here is my code:

fig,axes = plt.subplots(1,3,figsize=(15,5))
sns.regplot(x = 'Radio',y='Sales',data=df_advertising,ax = axes[0])
sns.residplot(x = 'Radio',y='Sales',data=df_advertising,ax = axes[1])

How can I get the values from my residual plot so that I can plot the corresponding histogram and see their distribution?

Thanks, any help will be appreciated. I'm just a beginner.

seaborn doesn't give you a direct way to get back the fit or the values (see also this question). It makes more sense to fit the model yourself, so that you know exactly what was fitted, and then plot the residuals. Below I use an example dataset:

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import statsmodels.api as sm

df_advertising = pd.DataFrame({'Radio':np.random.randint(1,10,100)})
df_advertising['Sales'] = 3*df_advertising['Radio'] + np.random.normal(10,4,100)

We can plot it using seaborn:

fig,axes = plt.subplots(1,2,figsize=(10,5))
sns.regplot(x = 'Radio',y='Sales',data=df_advertising,ax = axes[0])
sns.residplot(x = 'Radio',y='Sales',data=df_advertising,ax = axes[1])
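As a side note, the points that residplot draws can usually be read back from the matplotlib Axes, because the scatter ends up as the first entry in ax.collections. This is only a rough sketch, not something seaborn documents as an API: it assumes residplot draws its scatter first (which is what I see in practice), and the names demo, rng, plot_x and plot_resid plus the fixed seed are made up for the illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for df_advertising, seeded so the run is repeatable
rng = np.random.default_rng(0)
demo = pd.DataFrame({"Radio": rng.integers(1, 10, 100)})
demo["Sales"] = 3 * demo["Radio"] + rng.normal(10, 4, 100)

fig, ax = plt.subplots()
sns.residplot(x="Radio", y="Sales", data=demo, ax=ax)

# The scatter is a PathCollection; get_offsets() returns its (x, y) pairs
offsets = np.asarray(ax.collections[0].get_offsets())
plot_x, plot_resid = offsets[:, 0], offsets[:, 1]

print(plot_x[:5])
print(plot_resid[:5])
```

Here plot_resid holds the y-values of the residual plot, so it could be passed to a histogram directly. Fitting the model yourself, as below, is still the more transparent route.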


seaborn relies on statsmodels for some of its regression fits, so let's use statsmodels directly to fit the model and get the predictions:

mod = sm.OLS(df_advertising['Sales'],sm.add_constant(df_advertising['Radio']))
res = mod.fit()

test = df_advertising[['Radio']].drop_duplicates().sort_values('Radio')
predictions = res.get_prediction(sm.add_constant(test))
predictions = pd.concat([test,predictions.summary_frame(alpha=0.05)],axis=1)
predictions.head()

    Radio       mean   mean_se  mean_ci_lower  mean_ci_upper  obs_ci_lower  obs_ci_upper
13      1  11.132902  0.700578       9.742628      12.523175      3.862061     18.403742
6       2  14.480520  0.582916      13.323742      15.637298      7.250693     21.710347
2       3  17.828139  0.478925      16.877728      18.778550     10.628448     25.027829
4       4  21.175757  0.399429      20.383104      21.968411     13.995189     28.356326
10      5  24.523376  0.360990      23.807002      25.239750     17.350827     31.695924

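In the above, I create test from the unique, sorted Radio values so that predictions are not repeated for duplicate data points (my Radio values are integer counts). Now we have everything we need to plot. The residuals are available as the resid attribute of the fitted statsmodels results object: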

fig,axes = plt.subplots(1,3,figsize=(15,5))

sns.scatterplot(x='Radio',y='Sales',ax=axes[0],data=df_advertising)
axes[0].plot(predictions['Radio'], predictions['mean'], lw=2)
axes[0].fill_between(x=predictions['Radio'],
                     y1=predictions['mean_ci_lower'],y2=predictions['mean_ci_upper'],
                     facecolor='blue', alpha=0.2)

sns.scatterplot(x='Radio',y='Sales',ax=axes[1],
                data=pd.DataFrame({'Radio':df_advertising['Radio'],
                                  'Sales':res.resid})
               )
axes[1].axhline(0, ls='--',color="k")

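# sns.distplot is deprecated since seaborn 0.11; histplot with kde=True is the closest replacement
sns.histplot(res.resid,ax=axes[2],bins=20,kde=True)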
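If you want the numbers behind the histogram rather than just the picture, or a quick sanity check on res.resid, the residuals of a straight-line fit are simply the observed values minus the fitted ones. Below is a minimal sketch that uses np.polyfit in place of statsmodels; for a simple linear model it should agree with the OLS fit above up to floating-point error. The seed and the names rng, radio, sales and resid_np are assumptions made for the illustration.

```python
import numpy as np

# Hypothetical data with the same shape as the example dataset, seeded for repeatability
rng = np.random.default_rng(0)
radio = rng.integers(1, 10, 100)
sales = 3 * radio + rng.normal(10, 4, 100)

# Least-squares straight line: sales ~ slope * radio + intercept
slope, intercept = np.polyfit(radio, sales, 1)
fitted = slope * radio + intercept

# Residual = observed - fitted
resid_np = sales - fitted

# Bin counts and bin edges for a 20-bin histogram of the residuals
counts, edges = np.histogram(resid_np, bins=20)
print(counts)
print(edges)
```

With an intercept in the model the residuals average out to roughly zero, so the histogram should be centred near 0.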


