
Regression line fitting incorrectly in Matplotlib

I'm trying to add a regression line to this dataset using matplotlib.

Country             GDP per capita  Life satisfaction
Russia              9054.914        6
Turkey              9437.372        5.6
Hungary             12239.894       4.9
Poland              12495.334       5.8
Slovak Republic     15991.736       6.1
Estonia             17288.083       5.6
Greece              18064.288       4.8
Portugal            19121.592       5.1
Slovenia            20732.482       5.7
Spain               25864.721       6.5
Korea               27195.197       5.8
Italy               29866.581       6
Japan               32485.545       5.9
Israel              35343.336       7.4
New Zealand         37044.891       7.3
France              37675.006       6.5
Belgium             40106.632       6.9
Germany             40996.511       7
Finland             41973.988       7.4
Canada              43331.961       7.3
Netherlands         43603.115       7.3
Austria             43724.031       6.9
United Kingdom      43770.688       6.8
Sweden              49866.266       7.2
Iceland             50854.583       7.5
Australia           50961.865       7.3
Ireland             51350.744       7
Denmark             52114.165       7.5
United States       55805.204       7.2

But when I calculate the slope and intercept and plot the line, following this example: https://www.statology.org/scatterplot-with-regression-line-python/

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

country_stats = pd.read_csv("../data/country_stats.csv")

X = np.c_[country_stats["GDP per capita"]]
y = np.c_[country_stats["Life satisfaction"]]

country_stats.plot(kind='scatter', x="GDP per capita", y='Life satisfaction')
plt.axis([0, 60000, 0, 10])

# obtain m (slope) and b (intercept) of the linear regression line
m, b = np.polyfit(X[0], y[0], 1)

plt.plot(X, m*X+b, color='red')

plt.show()

The regression line does not fit the data, as the resulting plot shows:

[Image: scatter plot of the data with a red regression line that does not fit the points]

Am I doing something wrong that causes the poor fit?

I know I could address this in a few lines by using seaborn instead:

import seaborn as sns

sns.regplot(data=country_stats, x="GDP per capita", y="Life satisfaction", ci=None)

but I'd like to understand the underlying reason for the poor fit.

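The problem is that you fit your line to only one point. Because np.c_ turns each column into a 2-D array of shape (n, 1), X[0] and y[0] are just the first row, i.e. the first data point X[0], y[0]. Select the whole column instead: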

m, b = np.polyfit(X[:, 0], y[:, 0], 1)

Or, more cleanly, drop the extra dimension you added with np.c_ at the start and pass the columns directly:

X = country_stats["GDP per capita"]
y = country_stats["Life satisfaction"]
...
m, b = np.polyfit(X, y, 1)
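To see the cause concretely, here is a minimal NumPy-only sketch, assuming the table values from the question are typed in as plain arrays (the names `gdp`, `sat`, `m_bad`, `b_bad`, `m_ols` and `b_ols` are illustrative, not from the original code). It shows that `np.c_` produces `(n, 1)` columns, that fitting on `X[0], y[0]` gives a line that merely passes through the first point, and that fitting on the full column agrees with the closed-form least-squares formulas (a cross-check added here, not part of the original answer).

```python
import warnings

import numpy as np

# GDP per capita and life satisfaction, copied from the table in the question
gdp = np.array([
    9054.914, 9437.372, 12239.894, 12495.334, 15991.736, 17288.083,
    18064.288, 19121.592, 20732.482, 25864.721, 27195.197, 29866.581,
    32485.545, 35343.336, 37044.891, 37675.006, 40106.632, 40996.511,
    41973.988, 43331.961, 43603.115, 43724.031, 43770.688, 49866.266,
    50854.583, 50961.865, 51350.744, 52114.165, 55805.204,
])
sat = np.array([
    6, 5.6, 4.9, 5.8, 6.1, 5.6, 4.8, 5.1, 5.7, 6.5, 5.8, 6, 5.9, 7.4, 7.3,
    6.5, 6.9, 7, 7.4, 7.3, 7.3, 6.9, 6.8, 7.2, 7.5, 7.3, 7, 7.5, 7.2,
])

# np.c_ turns each 1-D array into a single column of shape (29, 1)
X = np.c_[gdp]
y = np.c_[sat]
print(X.shape, X[0].shape, X[:, 0].shape)  # (29, 1) (1,) (29,)

# Buggy version: X[0] and y[0] hold only the first row (one point), so a
# degree-1 fit is underdetermined. NumPy warns (RankWarning) and returns
# one of the infinitely many lines through that single point.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the warning for this demo
    m_bad, b_bad = np.polyfit(X[0], y[0], 1)

# Fixed version: fit on the whole column (equivalent to np.polyfit(gdp, sat, 1))
m, b = np.polyfit(X[:, 0], y[:, 0], 1)

# Cross-check against the closed-form ordinary least squares estimates:
#   m = sum((x - x_mean) * (y - y_mean)) / sum((x - x_mean) ** 2)
#   b = y_mean - m * x_mean
x_mean, y_mean = gdp.mean(), sat.mean()
m_ols = ((gdp - x_mean) * (sat - y_mean)).sum() / ((gdp - x_mean) ** 2).sum()
b_ols = y_mean - m_ols * x_mean
print(np.isclose(m, m_ols), np.isclose(b, b_ols))  # True True
```

Note that `np.polyfit` returns coefficients highest degree first, so for `deg=1` the unpacking order is slope, then intercept. Once `m` and `b` come from the full column, the question's `plt.plot(X, m*X+b, color='red')` draws the fitted line unchanged, since only the coefficients were wrong.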
