How to fix the code to predict Y based on X1 and X2

I have the following data:

X1   X2   Y
-10  4    0
-10  3    4
-10  2.5  8
-8   3    7
-8   4    8
-8   4.4  9
0    2    9
0    2.3  9.2
0    4    10
0    5    12

I need to create a simple regression model to predict Y given X1 and X2: Y = f(X1,X2).

This is my code:

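import pandas as pd
import matplotlib.pyplot as plt
from sklearn import linear_model
from sklearn.preprocessing import PolynomialFeatures

poly = PolynomialFeatures(degree=2)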
X1 = poly.fit_transform(df["X1"].values.reshape(-1,1))
X2 = poly.fit_transform(df["X2"].values.reshape(-1,1))
clf = linear_model.LinearRegression()
clf.fit([X1,X2], df["Y"].values.reshape(-1, 1))
print(clf.coef_)
print(clf.intercept_)

Y_test = clf.predict([X1, X2])
df_test=pd.DataFrame()
df_test["X1"] = df["X1"]
df_test["Y"] = df["Y"]
df_test["Y_PRED"] = Y_test

df_test.plot(x="X1",y=["Y","Y_PRED"], figsize=(10,5), grid=True)
plt.show()

But it fails at line clf.fit([X1,X2], df["Y"].values.reshape(-1, 1)) :

ValueError: Found array with dim 3. Estimator expected <= 2

It looks like the model cannot work with 2 input parameters X1 and X2. How should I change the code to fix it?

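The error comes from passing the Python list [X1, X2] to fit. scikit-learn converts that list into a single array of shape (2, 10, 3), which is 3-dimensional, while estimators expect a 2-D array of shape (n_samples, n_features). Concatenate the two feature matrices column-wise instead, for instance using pandas: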

import pandas as pd

X12_p = pd.concat([pd.DataFrame(X1), pd.DataFrame(X2)], axis=1)

Or the same using numpy:

import numpy as np

X12_p = np.concatenate([X1, X2], axis=1)
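To see why the list fails and the concatenation works, it can help to compare the array shapes directly. This is a minimal sketch using the data from the question; the names stacked and joined are illustrative only.

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures

# Data copied from the question.
x1 = np.array([-10, -10, -10, -8, -8, -8, 0, 0, 0, 0], dtype=float)
x2 = np.array([4, 3, 2.5, 3, 4, 4.4, 2, 2.3, 4, 5], dtype=float)

poly = PolynomialFeatures(degree=2)
X1 = poly.fit_transform(x1.reshape(-1, 1))  # columns: 1, x1, x1^2
X2 = poly.fit_transform(x2.reshape(-1, 1))  # columns: 1, x2, x2^2

# A list of two (10, 3) arrays becomes one 3-D array.
stacked = np.array([X1, X2])
# Column-wise concatenation keeps one row per sample.
joined = np.concatenate([X1, X2], axis=1)

print(stacked.shape)  # (2, 10, 3)
print(joined.shape)   # (10, 6)
```

Only the second shape has the (n_samples, n_features) layout that fit and predict accept.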

Your final snippet should look like:

# Fit
Y = df["Y"].values.reshape(-1,1)
X12_p = pd.concat([pd.DataFrame(X1), pd.DataFrame(X2)], axis=1)
clf.fit(X12_p, Y)

# Predict
Y_test = clf.predict(X12_p)
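Put together with the data from the question, the corrected flow might look like the sketch below. It assumes the DataFrame is built inline rather than loaded from a file, and it passes Y as a 1-D array so that predict also returns a 1-D array that can be assigned to a DataFrame column; y_pred is an illustrative name.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures

# Data copied from the question.
df = pd.DataFrame({
    "X1": [-10, -10, -10, -8, -8, -8, 0, 0, 0, 0],
    "X2": [4, 3, 2.5, 3, 4, 4.4, 2, 2.3, 4, 5],
    "Y": [0, 4, 8, 7, 8, 9, 9, 9.2, 10, 12],
})

# Expand each input separately, as in the question.
poly = PolynomialFeatures(degree=2)
X1 = poly.fit_transform(df[["X1"]].to_numpy())
X2 = poly.fit_transform(df[["X2"]].to_numpy())

# One 2-D matrix: 10 rows (samples) by 6 columns (features).
X12_p = np.concatenate([X1, X2], axis=1)

# Fit
clf = LinearRegression()
clf.fit(X12_p, df["Y"].to_numpy())

# Predict on the same rows the model was fitted on.
y_pred = clf.predict(X12_p)

df_test = df[["X1", "Y"]].copy()
df_test["Y_PRED"] = y_pred
print(df_test)
```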

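You can also evaluate performance metrics such as the RMSE. Note that mean_squared_error returns the MSE, so take its square root. This is the error on the training data, since the model predicts the same rows it was fitted on:

from sklearn.metrics import mean_squared_error
import numpy as np
print('rmse = {0:.5f}'.format(np.sqrt(mean_squared_error(Y, Y_test))))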

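Please also note that PolynomialFeatures adds a bias column of ones by default. Here X1 and X2 each get their own copy, and LinearRegression already fits an intercept, so you can exclude the bias term by changing the default parameter: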

PolynomialFeatures(degree=2, include_bias=False)
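An alternative worth considering is to expand both columns in a single PolynomialFeatures call, which also produces the interaction term X1*X2 that two separate transforms leave out. The sketch below assumes the same data as the question and wraps the steps in a scikit-learn pipeline; the variable names are illustrative.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

# Data copied from the question.
df = pd.DataFrame({
    "X1": [-10, -10, -10, -8, -8, -8, 0, 0, 0, 0],
    "X2": [4, 3, 2.5, 3, 4, 4.4, 2, 2.3, 4, 5],
    "Y": [0, 4, 8, 7, 8, 9, 9, 9.2, 10, 12],
})

X = df[["X1", "X2"]].to_numpy()  # shape (10, 2)
y = df["Y"].to_numpy()

# Degree-2 expansion of both inputs together, without the bias column:
# x1, x2, x1^2, x1*x2, x2^2. LinearRegression supplies the intercept.
model = make_pipeline(
    PolynomialFeatures(degree=2, include_bias=False),
    LinearRegression(),
)
model.fit(X, y)

n_features = model.named_steps["polynomialfeatures"].n_output_features_
rmse = np.sqrt(mean_squared_error(y, model.predict(X)))

print(n_features)  # 5
print("rmse = {0:.5f}".format(rmse))
```

Because the pipeline stores the fitted transform, new (X1, X2) pairs can be passed straight to model.predict without repeating the expansion by hand.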

Hope this helps.
