
How to plot a scikit-learn linear regression graph

I am new to scikit-learn and have been working on a regression problem on Kaggle (the King County house sales CSV). I am training a regression model to predict house prices and would like to plot the results, but I have no idea how to do so. I am using Python 3.6. Any advice or suggestions would be greatly appreciated.

# importing numpy, pandas, seaborn, matplotlib and scikit-learn

import numpy as np  # linear algebra
import pandas as pd  # data preprocessing, CSV file I/O
import seaborn as sns  # for plotting graphs
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, KFold, cross_val_score
from sklearn.linear_model import LinearRegression

data = pd.read_csv('kc_house_data.csv')
# 'date' is a string column and 'id' is not a predictor, so drop both
data = data.drop(['date', 'id'], axis=1)

# Y is the target (price); X holds all remaining columns as features
Y = data['price'].values
X = data.drop('price', axis=1).values

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.30, random_state=21)

reg = LinearRegression()
# shuffle=True is needed when random_state is set (recent scikit-learn versions raise a ValueError otherwise)
kfold = KFold(n_splits=15, shuffle=True, random_state=21)
cv_results = cross_val_score(reg, X_train, Y_train, cv=kfold, scoring='r2')

print(cv_results)

# mean R² across the folds, as a percentage
print(round(np.mean(cv_results) * 100, 2))
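Because the model uses many features, there is no single regression line to draw in 2D. A common alternative is to fit the model on the training split and scatter the predicted prices against the actual test prices, with a y = x reference line. A minimal sketch, assuming synthetic data from `make_regression` as a stand-in for the King County features (swap in your own `X` and `Y`); the function name `plot_predicted_vs_actual` is hypothetical:

```python
import matplotlib.pyplot as plt
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split


def plot_predicted_vs_actual(y_true, y_pred):
    """Scatter predicted vs actual values with a y = x reference line."""
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.scatter(y_true, y_pred, alpha=0.5)
    # Points lying on this diagonal are perfect predictions
    lo = min(y_true.min(), y_pred.min())
    hi = max(y_true.max(), y_pred.max())
    ax.plot([lo, hi], [lo, hi], color="red", linestyle="--", label="y = x")
    ax.set_xlabel("Actual price")
    ax.set_ylabel("Predicted price")
    ax.legend()
    return fig, ax


# Synthetic stand-in for the King County data (assumption: replace with X, Y from the CSV)
X, Y = make_regression(n_samples=500, n_features=10, noise=20.0, random_state=21)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.30, random_state=21)

# Fit on the training split, predict on the held-out test split
reg = LinearRegression().fit(X_train, Y_train)
Y_pred = reg.predict(X_test)
fig, ax = plot_predicted_vs_actual(Y_test, Y_pred)

if __name__ == "__main__":
    plt.show()
```

The closer the points sit to the dashed diagonal, the better the predictions; a funnel shape or curve away from it suggests the linear model is missing structure.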

You can use matplotlib for plotting, for example the cross-validation scores:

import matplotlib.pyplot as plt

plt.figure(figsize=(16, 9))
plt.plot(cv_results)  # one R² score per fold
plt.xlabel('Fold')
plt.ylabel('R² score')
plt.show()
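Note that `cv_results` holds one R² score per fold, so this plot shows how the score varies across the 15 folds rather than the regression itself. A bar chart with the mean drawn as a horizontal line can be easier to read. A minimal sketch, again assuming synthetic `make_regression` data as a stand-in for `X_train` and `Y_train`:

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold, cross_val_score

# Synthetic stand-in for X_train / Y_train (assumption: use your own training split)
X_train, Y_train = make_regression(n_samples=500, n_features=10, noise=20.0, random_state=21)

# shuffle=True is required when random_state is set in recent scikit-learn versions
kfold = KFold(n_splits=15, shuffle=True, random_state=21)
cv_results = cross_val_score(LinearRegression(), X_train, Y_train, cv=kfold, scoring="r2")

fig, ax = plt.subplots(figsize=(16, 9))
folds = np.arange(1, len(cv_results) + 1)
ax.bar(folds, cv_results)  # one bar per fold
ax.axhline(cv_results.mean(), color="red", linestyle="--",
           label=f"mean R² = {cv_results.mean():.3f}")
ax.set_xticks(folds)
ax.set_xlabel("Fold")
ax.set_ylabel("R² score")
ax.legend()

if __name__ == "__main__":
    plt.show()
```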

There are several types of plot you can use, such as a line plot, a scatter plot or a bar chart:

plt.barh(x, y)     # for a horizontal bar graph
plt.plot(x, y)     # for a line graph
plt.scatter(x, y)  # for a scatter graph
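To compare the three calls on the same data, here is a minimal sketch using made-up data (the noisy linear `x`, `y` values are hypothetical). Note that `barh` draws horizontal bars: its first argument gives the bar positions on the y-axis and its second gives the bar lengths.

```python
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical example data: a noisy linear relationship
rng = np.random.default_rng(0)
x = np.arange(10)
y = 2 * x + rng.normal(0, 1, size=x.size)

fig, (ax_bar, ax_line, ax_scatter) = plt.subplots(1, 3, figsize=(15, 4))
ax_bar.barh(x, y)         # horizontal bars: x gives positions, y gives lengths
ax_bar.set_title("barh")
ax_line.plot(x, y)        # points joined in order of x
ax_line.set_title("plot")
ax_scatter.scatter(x, y)  # unconnected points, usually best for regression data
ax_scatter.set_title("scatter")
fig.tight_layout()

if __name__ == "__main__":
    plt.show()
```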

Seaborn is a very useful visualization library. You can use 'seaborn.regplot' to plot the data and a regression fit line directly: it takes the predictor variable and the response variable and draws the data points together with the best-fit line. Note that regplot fits its own linear regression, so the line comes from seaborn rather than from your scikit-learn model. Here is the documentation on how to use it:

https://seaborn.pydata.org/generated/seaborn.regplot.html
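If you want the line to come from a fitted scikit-learn `LinearRegression` itself, one option is to fit on a single predictor and plot the model's predictions over an evenly spaced grid. A minimal sketch, assuming synthetic data; in the King County data a column such as `sqft_living` could play the role of `x`:

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression

# Synthetic single predictor (assumption: stands in for e.g. data['sqft_living'] and data['price'])
rng = np.random.default_rng(21)
x = rng.uniform(500, 4000, size=200)
y = 250 * x + 50_000 + rng.normal(0, 50_000, size=x.size)

# scikit-learn expects a 2-D feature array, hence reshape(-1, 1)
reg = LinearRegression().fit(x.reshape(-1, 1), y)

# Evaluate the fitted line on an evenly spaced grid spanning the data
x_grid = np.linspace(x.min(), x.max(), 100).reshape(-1, 1)
y_line = reg.predict(x_grid)

fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(x, y, alpha=0.5, label="data")
ax.plot(x_grid.ravel(), y_line, color="red", label="LinearRegression fit")
ax.set_xlabel("sqft_living (synthetic)")
ax.set_ylabel("price (synthetic)")
ax.legend()

if __name__ == "__main__":
    plt.show()
```

This only works as a 2D line because there is a single feature; with the full feature matrix, a predicted-vs-actual plot is the usual substitute.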

I have also done the same competition on Kaggle. For regression I would go for a scatter plot:

import matplotlib.pyplot as plt
plt.scatter(x, y)

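As for the visualisations for that particular competition, I would use the following code. It assumes a DataFrame 'train' with a 'SalePrice' target column and a list 'numeric' of numeric column names (for the King County data the target column is 'price'):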

# visualising some more outliers in the data values
import math
import matplotlib.pyplot as plt
import seaborn as sns

n_cols = 3
n_rows = math.ceil(len(numeric) / n_cols)
plt.figure(figsize=(18, 4 * n_rows))

for i, feature in enumerate(numeric, 1):
    if feature == 'MiscVal':  # stop plotting once this column is reached
        break
    plt.subplot(n_rows, n_cols, i)
    sns.scatterplot(x=feature, y='SalePrice', hue='SalePrice', palette='Blues', data=train)

    plt.xlabel(feature, size=15, labelpad=12.5)
    plt.ylabel('SalePrice', size=15, labelpad=12.5)
    plt.tick_params(axis='x', labelsize=12)
    plt.tick_params(axis='y', labelsize=12)
    plt.legend(loc='best', prop={'size': 10})

plt.tight_layout()
plt.show()
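If seaborn is not available, or you want the grid layout without the 'MiscVal' cutoff, here is a sketch of the same idea in plain matplotlib. It assumes a DataFrame with a numeric target column; the function name `scatter_grid` and the synthetic columns are hypothetical, and for the King County data you would pass `data` and `'price'`:

```python
import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def scatter_grid(df, target, features, n_cols=3):
    """Draw one scatter plot of `target` against each feature, n_cols panels per row."""
    n_rows = math.ceil(len(features) / n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 4 * n_rows), squeeze=False)
    for ax, feature in zip(axes.ravel(), features):
        ax.scatter(df[feature], df[target], s=10, alpha=0.5)
        ax.set_xlabel(feature, size=15, labelpad=12.5)
        ax.set_ylabel(target, size=15, labelpad=12.5)
        ax.tick_params(axis="both", labelsize=12)
    # Hide any unused panels in the last row
    for ax in axes.ravel()[len(features):]:
        ax.set_visible(False)
    fig.tight_layout()
    return fig, axes


# Hypothetical data standing in for the real training set
rng = np.random.default_rng(0)
demo = pd.DataFrame(rng.normal(size=(100, 4)), columns=["a", "b", "c", "d"])
demo["price"] = 3 * demo["a"] - demo["b"] + rng.normal(size=100)

fig, axes = scatter_grid(demo, "price", ["a", "b", "c", "d"])

if __name__ == "__main__":
    plt.show()
```

Computing the number of rows with `math.ceil` keeps the grid only as tall as needed, instead of reserving one row per feature.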

I have uploaded the full code for that competition to my GitHub if you want to have a look ;) (I am currently in the top 14% of that competition).

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.

© 2020-2024 STACKOOM.COM