簡體   English   中英

難以在sklearn中擬合多項式回歸曲線

[英]Trouble fitting a polynomial regression curve in sklearn

我是sklearn的新手,我有一個簡單的任務:給定15點的散點圖,我需要

  1. 將其中11個作為我的“培訓樣本”,
  2. 通過這11個點擬合3度的多項式曲線;
  3. 在15個點上繪制所得的多項式曲線。

但是我陷入了第二步。

這是數據圖:

%matplotlib notebook

import numpy as np from sklearn.model_selection 
import train_test_split from sklearn.linear_model 
import LinearRegression from sklearn.preprocessing import PolynomialFeatures

np.random.seed(0) 
n = 15 
x = np.linspace(0,10,n) + np.random.randn(n)/5 
y = np.sin(x)+x/6 + np.random.randn(n)/10

X_train, X_test, y_train, y_test = train_test_split(x, y, random_state=0)

plt.figure() plt.scatter(X_train, y_train, label='training data') 
plt.scatter(X_test, y_test, label='test data') 
plt.legend(loc=4);

然后,我將X_train的11個點轉換為3級的多邊形特征,如下所示:

degrees = 3
poly = PolynomialFeatures(degree=degree)

X_train_poly = poly.fit_transform(X_train)

然后,我嘗試使一條直線穿過轉換后的點(注意: X_train_poly.size = 364)。

linreg = LinearRegression().fit(X_train_poly, y_train)

我收到以下錯誤:

ValueError: Found input variables with inconsistent numbers of samples: [1, 11]

我已經閱讀了解決相似且通常更復雜的問題的各種問題(例如python中的多變量(多項式)最佳擬合曲線? ),但是我無法從中提取解決方案。

問題是X_train和y_train中的尺寸。 它是一維數組,因此將X條記錄中的每條記錄都視為一個單獨的變量。

如下使用.reshape命令應該可以解決問題:

# reshape data to have 11 records rather than 11 columns
X_trainT     = X_train.reshape(11,1)
y_trainT     = y_train.reshape(11,1)

# create polynomial features on the single va
poly         = PolynomialFeatures(degree=3)
X_train_poly = poly.fit_transform(X_trainT)

print (X_train_poly.shape)
# 

linreg       = LinearRegression().fit(X_train_poly, y_trainT)

該錯誤基本上意味着您的X_train_polyy_train不匹配,其中您的X_train_poly只有1套x,而y_train有11個值。 我不太確定您想要什么,但是我想多項式特征不是按照您想要的方式生成的。 您的代碼當前正在執行的操作是為單個11維點生成3級多項式特征。

我想您想為11個點中的每個點(實際上是每個x)生成3度多項式特征。 您可以使用循環或列表理解來做到這一點:

X_train_poly = poly.fit_transform([[i] for i in X_train])
X_train_poly.shape
# (11, 4)

現在,您可以看到X_train_poly有11個點,每個點是4維的,而不是單個364維的點。 這種新X_train_poly的形狀相匹配y_train和回歸可能會給你想要的東西:

linreg = LinearRegression().fit(X_train_poly, y_train)
linreg.coef_
# array([ 0.        , -0.79802899,  0.2120088 , -0.01285893])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM