
Linear Regression & Machine Learning Error

Good day devs,

I am currently working on linear regression with machine learning.

The LinearRegression class from sklearn.linear_model fits the data just fine, but an error is thrown when I try to plot the regression line with matplotlib.pyplot's plot() method.

You can find my code below:


import pandas
from pandas import DataFrame
import matplotlib.pyplot as plt

from sklearn.linear_model import LinearRegression

data = pandas.read_csv('cost_revenue.csv')

data.describe()

#The CSV file contains 5034 entries. 

X = DataFrame(data, columns=['production_budget_usd'])

Y = DataFrame(data, columns=['worldwide_gross_usd'])

plt.figure(figsize=(10,6))
plt.scatter(X, Y, alpha=0.3)

plt.title('Film Cost vs Global Revenue')
plt.xlabel('Production Budget $')
plt.ylabel('Worldwide Gross $')

plt.ylim(0, 3000000000)
plt.xlim(0, 450000000)
plt.show()

#This plots a scatterplot and works just fine.

regression = LinearRegression()
regression.fit(X, Y)

plt.figure(figsize=(10,6))
plt.scatter(X, Y, alpha=0.3)
plt.plot(X, regression.predict(X), color= 'red', linewidth=3)

plt.title('Film Cost vs Global Revenue')
plt.xlabel('Production budget $')
plt.ylabel('worldwide gross $')
plt.ylim(0,3000000000)
plt.xlim(0,450000000)
plt.show()

#This is the part of the code where it throws an exception.

It is supposed to draw a linear regression line on the graph, but instead it throws three errors that I haven't been able to debug. I would appreciate any help.

The errors are: TypeError, KeyError, and InvalidIndexError.

Debugging from top to bottom

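I believe the problem is in how "X" and "Y" are passed to matplotlib, specifically to the "plot" method. Sklearn's LinearRegression accepts a one-column DataFrame without trouble (it is already 2D, with shape (n, 1)), which is why fit() and predict() work. However, when plot() receives a DataFrame, matplotlib (depending on its version) may index it internally with NumPy-style syntax such as X[:, None], which pandas does not support. That is what produces the TypeError, then KeyError, then InvalidIndexError chain you are seeing.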
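A small way to see the difference, using made-up budget numbers in place of cost_revenue.csv (an assumption, since the file isn't available here). A one-column DataFrame already has a 2D shape. However, the tuple indexing X[:, None], the kind matplotlib may apply internally, is rejected by pandas. The same indexing on the underlying NumPy array works. The exact exception pandas raises varies by version, so this sketch catches it generically. The names arr and dataframe_indexing_failed are illustrative only.

```python
import pandas as pd

# Made-up budget figures standing in for cost_revenue.csv (assumption).
X = pd.DataFrame({"production_budget_usd": [1.0e6, 5.0e6, 2.0e7, 1.0e8]})

# A one-column DataFrame is already two-dimensional: (rows, 1).
print(X.shape)  # (4, 1)

# NumPy-style tuple indexing, as used inside some matplotlib versions,
# is not supported by DataFrame.__getitem__.
try:
    X[:, None]
    dataframe_indexing_failed = False
except Exception as exc:  # InvalidIndexError or KeyError, depending on pandas version
    dataframe_indexing_failed = True
    print(type(exc).__name__)

# The same indexing on the underlying NumPy array works fine.
arr = X.to_numpy()
print(arr[:, None].shape)  # (4, 1, 1)
```
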

Try converting X and Y to NumPy arrays with the DataFrame's values attribute (they remain 2D, with shape (n, 1)), and pass those arrays to the fit, predict, and plotting calls:

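# .values turns each one-column DataFrame into a NumPy array of shape (n, 1)
X = DataFrame(data, columns=['production_budget_usd']).values
Y = DataFrame(data, columns=['worldwide_gross_usd']).values

regression = LinearRegression()
regression.fit(X, Y)

plt.figure(figsize=(10,6))
plt.scatter(X, Y, alpha=0.3)
# Plain arrays avoid the pandas indexing error inside plot()
plt.plot(X, regression.predict(X), color='red', linewidth=3)
plt.show()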
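For an end-to-end check that runs without the original file, here is a minimal sketch on synthetic data. It assumes a perfectly linear made-up relationship, gross = 3 * budget + 5,000,000. It uses DataFrame.to_numpy(), which pandas recommends over .values and which returns the same (n, 1) array here. The Agg backend and the file name regression_sketch.png are assumptions so the script can run headless. Interactively, use plt.show() instead of savefig().

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

# Synthetic stand-in for cost_revenue.csv (assumption): gross = 3 * budget + 5e6.
budgets = np.array([1e6, 1e7, 5e7, 1e8, 2e8, 3e8])
data = pd.DataFrame({
    "production_budget_usd": budgets,
    "worldwide_gross_usd": 3 * budgets + 5e6,
})

# Double brackets select a one-column DataFrame; to_numpy() gives an (n, 1) array.
X = data[["production_budget_usd"]].to_numpy()
Y = data[["worldwide_gross_usd"]].to_numpy()

regression = LinearRegression()
regression.fit(X, Y)

fig = plt.figure(figsize=(10, 6))
plt.scatter(X, Y, alpha=0.3)
plt.plot(X, regression.predict(X), color="red", linewidth=3)
plt.title("Film Cost vs Global Revenue")
plt.xlabel("Production Budget $")
plt.ylabel("Worldwide Gross $")
fig.savefig("regression_sketch.png")  # hypothetical output file name

# With a 2D Y, coef_ has shape (1, 1) and intercept_ has shape (1,).
slope = regression.coef_[0][0]
intercept = regression.intercept_[0]
print(slope, intercept)
```

Because the synthetic data is exactly linear, the fitted slope and intercept should come out very close to 3 and 5,000,000. On the real CSV they will of course differ.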

One more question: is the data file "cost_revenue.csv" loaded correctly, and does it contain the expected data?
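To answer that question quickly, a sketch like the following checks that the two columns used above exist and have no missing values. It also prints dtypes, since numbers stored with "$" or thousands separators would load as text rather than numbers. The helper check_columns and the tiny inline CSV are hypothetical, made up for illustration. With the real file, pass the result of pandas.read_csv('cost_revenue.csv') instead.

```python
import io
import pandas as pd

# Hypothetical helper: verify that the regression columns exist
# and contain no missing values before fitting.
def check_columns(df, required):
    missing_cols = [c for c in required if c not in df.columns]
    if missing_cols:
        return f"missing columns: {missing_cols}"
    null_counts = df[required].isna().sum()
    if null_counts.any():
        return f"missing values: {null_counts[null_counts > 0].to_dict()}"
    return "ok"

# Tiny made-up CSV standing in for cost_revenue.csv (assumption);
# the second row deliberately lacks a gross value.
csv_text = """production_budget_usd,worldwide_gross_usd
1000000,5000000
2000000,
"""
data = pd.read_csv(io.StringIO(csv_text))
print(data.shape)
print(data.dtypes)
print(check_columns(data, ["production_budget_usd", "worldwide_gross_usd"]))
```
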
