Good day devs,
I am currently working on Linear Regression with Machine Learning.
The LinearRegression class from sklearn.linear_model works just fine, but an error is thrown when I try to draw the regression line with the plot() method of matplotlib.pyplot.
You can find my code below:
import pandas
from pandas import DataFrame
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
data = pandas.read_csv('cost_revenue.csv')
data.describe()
#The CSV file contains 5034 entries.
X = DataFrame(data, columns=['production_budget_usd'])
Y = DataFrame(data, columns=['worldwide_gross_usd'])
plt.figure(figsize=(10,6))
plt.scatter(X, Y, alpha=0.3)
plt.title('Film Cost vs Global Revenue')
plt.xlabel('Production Budget $')
plt.ylabel('Worldwide Gross $')
plt.ylim(0, 3000000000)
plt.xlim(0, 450000000)
plt.show()
#This plots a scatterplot and works just fine.
regression = LinearRegression()
regression.fit(X, Y)
plt.figure(figsize=(10,6))
plt.scatter(X, Y, alpha=0.3)
plt.plot(X, regression.predict(X), color= 'red', linewidth=3)
plt.title('Film Cost vs Global Revenue')
plt.xlabel('Production budget $')
plt.ylabel('worldwide gross $')
plt.ylim(0,3000000000)
plt.xlim(0,450000000)
plt.show()
#This is the part of the code where it throws an exception
It is supposed to draw a linear regression line on the graph, but it throws 3 errors that I haven't been able to debug. I would appreciate any help.
The errors are: TypeError, KeyError, InvalidIndexError
Debugging from top to bottom
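I believe the problem is in the way you are passing "X" and "Y" to matplotlib's "plot" method, not in the regression itself. A single-column DataFrame is already a 2D structure of shape (n_samples, 1), which is what Sklearn's LinearRegression expects, so fit() works. In some matplotlib/pandas version combinations, however, plot() indexes its inputs NumPy-style (for example x[:, None]), and a pandas DataFrame rejects that kind of multidimensional key, which produces the TypeError, KeyError and InvalidIndexError chain.
Have you tried converting X and Y into NumPy arrays using the DataFrame's values attribute and passing those to the fit and predict methods?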
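To see whether the number of dimensions is really the issue, you can compare the single-column DataFrame with its .values array. This is a minimal sketch that assumes a hypothetical three-row stand-in for cost_revenue.csv (the numbers are made up); it shows that both objects have the same 2D shape and give the same fitted predictions, so the conversion changes the type, not the dimensionality.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

# Hypothetical stand-in for cost_revenue.csv (made-up values, same column names)
data = pd.DataFrame({
    "production_budget_usd": [1_000_000, 5_000_000, 20_000_000],
    "worldwide_gross_usd": [3_000_000, 12_000_000, 55_000_000],
})

X_df = pd.DataFrame(data, columns=["production_budget_usd"])  # as in the question
X_arr = X_df.values  # plain NumPy array

# Both are already 2D with shape (n_samples, 1)
print(X_df.ndim, X_df.shape)               # 2 (3, 1)
print(type(X_arr).__name__, X_arr.shape)   # ndarray (3, 1)

# sklearn accepts either form and fits the same model
model_df = LinearRegression().fit(X_df, data["worldwide_gross_usd"])
model_arr = LinearRegression().fit(X_arr, data["worldwide_gross_usd"].values)
print(np.allclose(model_df.predict(X_df), model_arr.predict(X_arr)))  # True
```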
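# .values turns each single-column DataFrame into a NumPy array of shape (n, 1)
X = DataFrame(data, columns=['production_budget_usd']).values
Y = DataFrame(data, columns=['worldwide_gross_usd']).values
regression = LinearRegression()
regression.fit(X, Y)
plt.figure(figsize=(10,6))
plt.scatter(X, Y, alpha=0.3)
# plot() now receives NumPy arrays, which it can index without errors
plt.plot(X, regression.predict(X), color='red', linewidth=3)
plt.show()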
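For reference, here is a self-contained version of the same fix that runs without the CSV. It is a sketch that assumes synthetic data (a made-up linear relationship plus noise) in place of cost_revenue.csv, and it selects the non-interactive Agg backend so it can run headless; in a notebook you can drop the matplotlib.use line.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove this line in a notebook/GUI session
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

# Synthetic stand-in for cost_revenue.csv: gross is roughly 3x budget plus noise
rng = np.random.default_rng(0)
budget = rng.uniform(1e6, 4e8, size=200)
gross = 3.0 * budget + rng.normal(0, 5e7, size=200)
data = pd.DataFrame({"production_budget_usd": budget, "worldwide_gross_usd": gross})

# Convert to NumPy arrays before fitting and plotting
X = data[["production_budget_usd"]].values
Y = data[["worldwide_gross_usd"]].values

regression = LinearRegression().fit(X, Y)

plt.figure(figsize=(10, 6))
plt.scatter(X, Y, alpha=0.3)
plt.plot(X, regression.predict(X), color="red", linewidth=3)
plt.title("Film Cost vs Global Revenue")
plt.xlabel("Production Budget $")
plt.ylabel("Worldwide Gross $")
plt.ylim(0, 3_000_000_000)
plt.xlim(0, 450_000_000)
plt.show()  # a no-op under Agg; displays the figure with an interactive backend
```

Because the model is linear, the predicted points all lie on one straight line, so plotting them in the original (unsorted) order still draws a single line.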
One more question: is the data file "cost_revenue.csv" loaded correctly, and does it contain the expected data?
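To answer that quickly, a small check like the following can help. check_cost_revenue is a hypothetical helper, not part of pandas; it only verifies that the two columns used above exist, are numeric, and have no missing values. A budget column read as text (for example values like "$1,000") would also cause problems when fitting.

```python
import pandas as pd

REQUIRED = ["production_budget_usd", "worldwide_gross_usd"]

def check_cost_revenue(data: pd.DataFrame) -> list:
    """Hypothetical helper: return a list of problems (empty if the frame looks usable)."""
    problems = []
    missing = [c for c in REQUIRED if c not in data.columns]
    if missing:
        problems.append(f"missing columns: {missing}")
        return problems
    for col in REQUIRED:
        # Columns parsed as text (e.g. "$1,000") end up with a non-numeric dtype
        if not pd.api.types.is_numeric_dtype(data[col]):
            problems.append(f"{col} is not numeric (dtype {data[col].dtype})")
        n_null = int(data[col].isna().sum())
        if n_null:
            problems.append(f"{col} has {n_null} missing values")
    return problems
```

For example, calling print(len(data), check_cost_revenue(data)) right after read_csv should print 5034 and an empty list if the file matches the description in the question.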