簡體   English   中英

提高 Python 中機器學習 model 預測的准確性

[英]Improving accuracy of machine learning model predictions in Python

我們目前正在為一家本地公司在 Python 中實現 ML model,以預測 0-999 點范圍內的信用評分。 從數據庫中提取了 11 個自變量(信用歷史和支付行為)和一個因變量(信用評分)。 客戶表示,生產 model 的 MAE 必須小於 100 點才能使用。 問題是我們已經嘗試了幾種算法來實現這種回歸,但我們的模型無法很好地泛化到未見過的數據。 到目前為止,性能最好的算法似乎是隨機森林,但它在測試數據上的 MAE 仍然超出了可接受的值。 這是我們的代碼:

import numpy as np
from sklearn.preprocessing import RobustScaler
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Ridge
from sklearn.linear_model import Lasso
from sklearn.linear_model import ElasticNet
from sklearn.ensemble import RandomForestRegressor
from sklearn import metrics
from keras.layers import Dense
from keras.models import Sequential

# Linear Model
def GetLinearModel(X, y):
    """Fit an ordinary least-squares regressor on (X, y) and return it."""
    # fit() hands back the fitted estimator, so it can be returned directly.
    return LinearRegression().fit(X, y)

# Ridge Regression
def GetRidge(X, y):
    """Fit a ridge regressor (alpha=0.01) and return it.

    Args:
        X: Training feature matrix.
        y: Training targets.

    Returns:
        The fitted ``Ridge`` estimator.
    """
    model = Ridge(alpha=0.01)
    # Fit on the arguments rather than on module-level globals, so the
    # function trains on whatever data the caller passes in.
    model.fit(X, y)
    return model

# LASSO Regression
def GetLASSO(X, y):
    """Fit a LASSO regressor (alpha=0.01) and return it.

    Args:
        X: Training feature matrix.
        y: Training targets.

    Returns:
        The fitted ``Lasso`` estimator.
    """
    model = Lasso(alpha=0.01)
    # Fit on the arguments rather than on module-level globals, so the
    # function trains on whatever data the caller passes in.
    model.fit(X, y)
    return model

# ElasticNet Regression
def GetElasticNet(X, y):
    """Fit an elastic-net regressor (alpha=0.01) and return it.

    Args:
        X: Training feature matrix.
        y: Training targets.

    Returns:
        The fitted ``ElasticNet`` estimator.
    """
    model = ElasticNet(alpha=0.01)
    # Fit on the arguments rather than on module-level globals, so the
    # function trains on whatever data the caller passes in.
    model.fit(X, y)
    return model

# Random Forest
def GetRandomForest(X, y):
    """Train a 32-tree random forest with a fixed seed on (X, y) and return it."""
    forest = RandomForestRegressor(random_state=0, n_estimators=32)
    # fit() returns the fitted estimator itself.
    return forest.fit(X, y)

# Neural Networks
def GetNeuralNetworks(X, y):
    """Build, train and return a small fully connected regression network.

    The network stacks five hidden ReLU layers of 32 units each and a single
    linear output unit. It is compiled with the Adam optimizer and a mean
    absolute error loss.

    Args:
        X: 2-D feature matrix; its second dimension sets the input width.
        y: Training targets.

    Returns:
        The trained Keras ``Sequential`` model.
    """
    # Size the input layer from the data instead of hard-coding 11 features,
    # so the function keeps working if columns are added or removed.
    n_features = np.shape(X)[1]

    model = Sequential()
    model.add(Dense(32, activation='relu', input_dim=n_features))
    # Four further identical hidden layers.
    for _ in range(4):
        model.add(Dense(units=32, activation='relu'))
    model.add(Dense(units=1))
    model.compile(optimizer='adam', loss='mean_absolute_error')
    # verbose=0 keeps the 500-epoch training run silent.
    model.fit(X, y, batch_size=100, epochs=500, verbose=0)
    return model

# Train data: each row is 11 feature values followed by the credit score (last column)
train_set = np.array([\
[2, 5, 9, 28, 0, 0.153668, 500, 0, 0, 0.076923077, 0, 800],\
[3, 0, 0, 42, 2, 0.358913, 500, 0, 0, 0.230769231, 0, 900],\
[3, 0, 0, 12, 2, 0, 500, 0, 0, 0.076923077, 0, 500],\
[1, 0, 0, 6, 1, 0.340075, 457, 0, 0, 0.076923077, 0, 560],\
[1, 5, 0, 12, 3, 0.458358, 457, 0, 0, 0.153846154, 0, 500],\
[1, 3, 4, 32, 2, 0.460336, 457, 0, 0, 0.153846154, 0, 600],\
[3, 0, 0, 42, 4, 0.473414, 500, 0, 0, 0.230769231, 0, 700],\
[1, 3, 0, 16, 0, 0.332991, 500, 0, 0, 0.076923077, 0, 600],\
[1, 3, 19, 27, 0, 0.3477, 500, 0, 0, 0.076923077, 0, 580],\
[1, 5, 20, 74, 1, 0.52076, 500, 0, 0, 0.230769231, 0, 550],\
[6, 0, 0, 9, 3, 0, 500, 0, 0, 0.076923077, 0, 570],\
[1, 8, 47, 0, 0, 0.840656, 681, 0, 0, 0, 0, 50],\
[1, 0, 0, 8, 14, 0, 681, 0, 0, 0.076923077, 0, 400],\
[5, 6, 19, 7, 1, 0.251423, 500, 0, 1, 0.076923077, 1, 980],\
[1, 0, 0, 2, 2, 0.121852, 500, 1, 0, 0.076923077, 9, 780],\
[2, 0, 0, 4, 0, 0.37242, 500, 1, 0, 0.076923077, 0, 920],\
[3, 4, 5, 20, 0, 0.37682, 500, 1, 0, 0.076923077, 0, 700],\
[3, 8, 17, 20, 0, 0.449545, 500, 1, 0, 0.076923077, 0, 300],\
[3, 12, 30, 20, 0, 0.551193, 500, 1, 0, 0.076923077, 0, 30],\
[0, 1, 10, 8, 3, 0.044175, 500, 0, 0, 0.076923077, 0, 350],\
[1, 0, 0, 14, 3, 0.521714, 500, 0, 0, 0.153846154, 0, 650],\
[2, 4, 15, 0, 0, 0.985122, 500, 0, 0, 0, 0, 550],\
[2, 4, 34, 0, 0, 0.666666, 500, 0, 0, 0, 0, 600],\
[1, 16, 17, 10, 3, 0.299756, 330, 0, 0, 0.153846154, 0, 650],\
[2, 0, 0, 16, 1, 0, 500, 0, 0, 0.076923077, 0, 900],\
[2, 5, 31, 26, 0, 0.104847, 500, 0, 0, 0.076923077, 0, 850],\
[2, 6, 16, 34, 1, 0.172947, 500, 0, 0, 0.153846154, 0, 900],\
[1, 4, 0, 16, 6, 0.206403, 500, 0, 0, 0.153846154, 0, 630],\
[1, 8, 20, 12, 5, 0.495897, 500, 0, 0, 0.153846154, 0, 500],\
[1, 8, 46, 8, 6, 0.495897, 500, 0, 0, 0.153846154, 0, 250],\
[2, 0, 0, 4, 8, 0, 500, 0, 0, 0.076923077, 0, 550],\
[2, 6, 602, 0, 0, 0, 500, 0, 0, 0, 0, 20],\
[0, 12, 5, 21, 0, 0.158674, 645, 0, 0, 0.153846154, 0, 850],\
[0, 12, 20, 21, 0, 0.158674, 645, 0, 0, 0.153846154, 0, 700],\
[1, 0, 0, 33, 0, 0.041473, 645, 0, 0, 0.230769231, 0, 890],\
[1, 0, 0, 12, 2, 0.147325, 500, 0, 0, 0.076923077, 0, 780],\
[1, 8, 296, 0, 0, 2.891695, 521, 0, 0, 0, 0, 1],\
[1, 0, 0, 4, 0, 0.098953, 445, 0, 0, 0.076923077, 0, 600],\
[1, 0, 0, 4, 0, 0.143443, 500, 0, 0, 0.076923077, 0, 500],\
[0, 8, 20, 0, 0, 1.110002, 833, 0, 0, 0, 0, 100],\
[0, 0, 0, 8, 2, 0, 833, 0, 0, 0.076923077, 0, 300],\
[1, 4, 60, 20, 6, 0.78685, 833, 0, 0, 0.153846154, 0, 100],\
[1, 4, 112, 20, 6, 0.78685, 833, 0, 0, 0.153846154, 0, 1],\
[1, 0, 0, 21, 10, 0.305556, 500, 0, 0, 0.307692308, 0, 150],\
[1, 0, 0, 21, 10, 0.453743, 500, 0, 0, 0.307692308, 0, 300],\
[0, 0, 0, 8, 0, 0, 570, 0, 0, 0, 0, 500],\
[0, 10, 10, 8, 0, 0.325975, 570, 0, 0, 0.076923077, 0, 450],\
[1, 7, 16, 15, 1, 0.266311, 570, 0, 0, 0.076923077, 0, 450],\
[1, 1, 32, 30, 4, 0.134606, 570, 0, 0, 0.230769231, 0, 250],\
[1, 0, 0, 32, 5, 0.105576, 570, 0, 0, 0.230769231, 0, 430],\
[1, 4, 34, 32, 5, 0.519103, 500, 0, 0, 0.230769231, 0, 350],\
[1, 0, 0, 12, 1, 0.109559, 669, 0, 0, 0.076923077, 0, 600],\
[11, 4, 15, 2, 3, 0.235709, 500, 0, 1, 0, 2, 900],\
[11, 4, 15, 1, 6, 0.504134, 500, 0, 1, 0, 2, 534],\
[2, 0, 0, 15, 9, 0.075403, 500, 0, 0, 0.076923077, 0, 573],\
[10, 0, 0, 51, 11, 2.211951, 500, 0, 0, 0.307692308, 7, 547],\
[9, 0, 0, 28, 4, 0.328037, 500, 0, 0, 0.230769231, 0, 747],\
[9, 2, 0, 0, 0, 0.166666, 500, 0, 1, 0.076923077, 4, 448],\
[8, 0, 0, 4, 1, 0, 500, 0, 1, 0, 1, 719],\
[3, 4, 15, 8, 1, 0.150237, 500, 0, 1, 0, 0, 827],\
[7, 138, 35, 37, 1, 0.414154, 500, 0, 1, 0.076923077, 3, 950],\
[6, 19, 41, 84, 1, 0.41248, 500, 0, 0, 0.230769231, 0, 750],\
[1, 6, 10, 0, 0, 0.232647, 500, 0, 1, 0, 0, 700],\
[0, 10, 27, 0, 0, 0.411712, 4, 0, 0, 0, 0, 520],\
[3, 31, 45, 80, 0, 0.266299, 500, 0, 0, 0.153846154, 0, 750],\
[3, 24, 49, 2, 1, 0.981102, 500, 0, 0, 0.076923077, 0, 550],\
[1, 12, 31, 11, 1, 0.333551, 500, 0, 0, 0.153846154, 0, 500],\
[0, 18, 30, 13, 2, 0.602826, 406, 0, 0, 0.076923077, 0, 580],\
[2, 2, 31, 0, 0, 1, 500, 0, 0, 0, 0, 427],\
[1, 18, 40, 83, 1, 0.332792, 500, 0, 0, 0.307692308, 0, 485],\
[2, 14, 35, 9, 3, 0.39671, 500, 0, 1, 0.076923077, 3, 664],\
[2, 88, 32, 7, 2, 0.548066, 500, 0, 1, 0, 1, 90],\
[2, 26, 26, 32, 2, 0.415991, 500, 0, 0, 0.153846154, 0, 90],\
[1, 14, 30, 11, 1, 0.51743, 599, 0, 0, 0.153846154, 0, 300],\
[1, 15, 28, 26, 0, 0.4413, 500, 0, 0, 0.076923077, 0, 610],\
[1, 17, 50, 34, 1, 0.313789, 500, 0, 0, 0.230769231, 0, 450],\
[0, 4, 15, 0, 0, 0.535163, 500, 0, 0, 0, 0, 375],\
[0, 8, 23, 0, 0, 0.51242, 500, 0, 0, 0, 0, 550],\
[3, 6, 44, 2, 3, 0.268062, 500, 0, 1, 0, 2, 744],\
[6, 38, 51, 35, 0, 0.28396, 500, 0, 1, 0.076923077, 1, 980],\
[6, 5, 63, 6, 5, 0.566661, 500, 0, 0, 0.153846154, 0, 850],\
[6, 0, 0, 0, 0, 0.174852, 500, 0, 0, 0, 0, 800],\
[6, 4, 60, 6, 3, 0.517482, 500, 0, 0, 0.076923077, 0, 750],\
[5, 16, 52, 49, 4, 0.378441, 500, 0, 1, 0.153846154, 6, 720],\
[5, 26, 84, 103, 1, 0.472361, 500, 0, 0, 0.230769231, 0, 300],\
[1, 6, 34, 36, 1, 0.298553, 500, 0, 1, 0.153846154, 0, 628],\
[5, 6, 65, 34, 0, 0.301907, 500, 0, 0, 0.153846154, 0, 710],\
[3, 16, 177, 29, 10, 0.501831, 500, 1, 0, 0.153846154, 0, 40],\
[2, 5, 45, 0, 0, 0.351668, 500, 0, 0, 0, 0, 708],\
[2, 7, 57, 7, 4, 0.432374, 500, 0, 0, 0.153846154, 0, 753],\
[1, 1, 75, 36, 0, 0.154085, 500, 0, 0, 0.076923077, 0, 610],\
[1, 16, 63, 13, 2, 0.331244, 500, 0, 0, 0.076923077, 0, 620],\
[1, 3, 55, 9, 0, 0.377253, 500, 0, 0, 0.076923077, 0, 640],\
[1, 1, 75, 5, 5, 0.877696, 500, 0, 0, 0.076923077, 0, 480],\
[1, 0, 0, 8, 5, 0.208742, 500, 0, 0, 0.153846154, 0, 520],\
[1, 3, 55, 29, 0, 0.228812, 678, 0, 0, 0.153846154, 0, 547],\
[1, 0, 0, 2, 2, 0.090459, 553, 0, 0, 0.076923077, 0, 535],\
[0, 4, 29, 0, 0, 0.292161, 500, 0, 0, 0, 0, 594],\
[1, 3, 64, 18, 6, 0.602431, 500, 0, 0, 0.230769231, 0, 500],\
[6, 9, 40, 74, 0, 0.567179, 500, 0, 0, 0.076923077, 0, 910],\
[4, 10, 65, 14, 1, 0.423915, 500, 0, 1, 0, 1, 713],\
[2, 0, 0, 6, 1, 0.114637, 500, 0, 0, 0.076923077, 0, 650],\
[5, 18, 74, 34, 0, 0.489314, 500, 0, 0, 0.153846154, 0, 500],\
[0, 6, 43, 9, 15, 0.599918, 612, 0, 0, 0.153846154, 0, 100],\
[4, 25, 64, 135, 0, 0.472659, 500, 0, 0, 0.230769231, 0, 560],\
[6, 3, 94, 12, 10, 0.31713, 500, 0, 0, 0.230769231, 0, 580],\
[1, 4, 69, 18, 9, 0.412528, 500, 0, 0, 0.307692308, 0, 362],\
[2, 21, 58, 21, 0, 0.53184, 500, 0, 0, 0.153846154, 0, 370],\
[0, 0, 0, 21, 4, 0.033438, 500, 0, 0, 0.153846154, 0, 500],\
[0, 10, 53, 20, 0, 0.619595, 500, 0, 0, 0.076923077, 0, 200],\
[2, 15, 63, 28, 2, 0.593453, 500, 0, 0, 0.153846154, 0, 574],\
[3, 2, 84, 21, 1, 0.302636, 500, 0, 0, 0.153846154, 0, 790],\
[4, 19, 47, 28, 0, 0.256892, 500, 0, 0, 0.076923077, 0, 748],\
[1, 0, 0, 0, 0, 0.119599, 500, 0, 0, 0, 0, 517],\
[3, 10, 53, 22, 0, 0.419703, 500, 0, 0, 0.153846154, 0, 800],\
[4, 7, 66, 70, 1, 0.362268, 500, 0, 0, 0.230769231, 0, 550],\
[0, 16, 88, 18, 3, 0.597145, 16, 0, 0, 0.153846154, 0, 50],\
[5, 8, 38, 0, 0, 0.666666, 500, 0, 0, 0, 0, 667]])

# Test data: same column layout as train_set (11 feature values, then the credit score)
test_set = np.array([\
[2, 16, 87, 30, 0, 0.168057, 500, 0, 1, 0.153846154, 1, 760],\
[3, 5, 83, 6, 4, 0.273522, 500, 0, 0, 0.076923077, 0, 877],\
[1, 0, 0, 12, 0, 0.262797, 500, 0, 0, 0.153846154, 0, 596],\
[2, 15, 46, 28, 0, 0.495495, 500, 0, 0, 0.076923077, 0, 680],\
[1, 0, 0, 22, 9, 0.254813, 500, 0, 0, 0.230769231, 0, 450],\
[3, 19, 59, 12, 0, 0.437851, 500, 0, 0, 0.153846154, 0, 850],\
[4, 5, 28, 0, 0, 0.34559, 500, 0, 1, 0.076923077, 1, 800],\
[1, 5, 58, 0, 0, 0.385379, 500, 0, 0, 0, 0, 641],\
[1, 4, 65, 15, 1, 0.2945, 500, 0, 0, 0.153846154, 0, 644],\
[0, 0, 0, 9, 3, 0.421612, 500, 0, 0, 0.076923077, 0, 580],\
[3, 31, 83, 2, 2, 0.436883, 500, 0, 0, 0.076923077, 0, 410],\
[0, 0, 0, 18, 5, 0.044898, 377, 0, 0, 0.230769231, 0, 520],\
[0, 8, 49, 12, 3, 0.428529, 500, 0, 1, 0.076923077, 1, 370],\
[0, 22, 89, 2, 1, 0.819431, 500, 0, 0, 0.076923077, 0, 440],\
[3, 27, 63, 124, 0, 0.375306, 500, 0, 0, 0.076923077, 0, 880],\
[3, 20, 64, 18, 5, 0.439412, 500, 0, 1, 0.076923077, 3, 820],\
[1, 6, 34, 2, 12, 0.495654, 500, 0, 0, 0.076923077, 0, 653],\
[0, 14, 225, 0, 0, 1, 486, 0, 0, 0, 0, 1],\
[2, 8, 87, 32, 1, 0.829792, 500, 0, 0, 0.230769231, 0, 570],\
[2, 15, 46, 24, 4, 0.500442, 500, 0, 0, 0.153846154, 0, 568]])

# split datasets into independent and dependent variables
# (the last column is the target; every column before it is a feature)
X_train, y_train = train_set[:, :-1], train_set[:, -1]
X_test, y_test = test_set[:, :-1], test_set[:, -1]

# feature scaling
sc = RobustScaler()
# Learn the scaling statistics from the training data only ...
X_train = sc.fit_transform(X_train)
# ... and reuse them for the test data. Re-fitting on the test set would scale
# it differently from the data the models were trained on.
X_test = sc.transform(X_test)

# Train each candidate on the scaled training data and print its mean
# absolute error on the test set, one line per model.
for label, build in (
    ("Linear", GetLinearModel),
    ("Ridge", GetRidge),
    ("LASSO", GetLASSO),
    ("ElasticNet", GetElasticNet),
    ("Random Forest", GetRandomForest),
    ("Neural Networks", GetNeuralNetworks),
):
    reg = build(X_train, y_train)
    y_pred = reg.predict(X_test)
    mae = metrics.mean_absolute_error(y_test, y_pred)
    print("%15s: %10f" % (label, mae))

Output:

         Linear: 141.265089
          Ridge: 141.267797
          LASSO: 141.274700
     ElasticNet: 141.413544
  Random Forest: 102.701562
WARNING:tensorflow:11 out of the last 11 calls to <function Model.make_predict_function.<locals>.predict_function at 0x00000229766694C0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
Neural Networks: 122.301840

任何有關如何提高 model 精度的幫助將不勝感激。

親切的問候。

我使用的是您在示例中提供的數據集。我還創建了訓練、驗證和測試數據集,以避免 @Prayson W. Daniel 提到的數據泄漏。

對於神經網絡,您需要確保標簽和特征都被縮放。您可以使用 StandardScaler(標准縮放器)。 您還需要確保特征和標簽都是二維的。在您的示例中,您的 label 是一維數組。

使用以下代碼提取二維標簽

# Indexing the column with a list ([-1]) keeps the result 2-D, shape (n_rows, 1).
Train_labels=train_set[:,[-1]]

您可以使用 StandardScaler 規范化數據,您需要確保標簽和特征都需要規范化

現在,一旦您構建 ANN,您需要確保您的網絡可以看到大量數據。因為您的訓練和測試數據非常少,您可以使用 K 折交叉驗證。我在這里沒有使用 k 折,只是創建了 model

from keras import regularizers
def build_model(input_dim=11):
    """Build and compile a regularised feed-forward regression network.

    The network has three hidden ReLU layers of 21 units. The first two use
    L2 kernel regularisation (0.001) and are each followed by 20% dropout.
    A single linear unit produces the output. The model is compiled with the
    Nadam optimizer, an MAE loss and the custom R2 metric.

    NOTE(review): ``K`` and ``r2_keras_custom`` are not defined in this
    snippet; presumably ``K`` is an alias for the Keras package and
    ``r2_keras_custom`` is a user-defined metric function. Confirm both.

    Args:
        input_dim: Number of input features (defaults to 11).

    Returns:
        The compiled, untrained model.
    """
    net = K.models.Sequential()
    net.add(K.layers.Dense(units=21, activation='relu',
                           kernel_regularizer=regularizers.l2(0.001),
                           input_dim=input_dim))
    net.add(K.layers.Dropout(0.2))
    net.add(K.layers.Dense(21, activation='relu',
                           kernel_regularizer=regularizers.l2(0.001)))
    net.add(K.layers.Dropout(0.2))
    net.add(K.layers.Dense(21, activation='relu'))
    net.add(K.layers.Dense(1))

    # Compile the model. Keras documents ``metrics`` as a list, so the
    # single custom metric is wrapped in one.
    net.compile(optimizer=K.optimizers.Nadam(), loss='mae',
                metrics=[r2_keras_custom])
    return net


# Build the network, then train it while tracking the loss and metric on the
# (X_test, Y_test) pair after every epoch.
model = build_model()
history = model.fit(
    x=X_train,
    y=Y_train,
    epochs=200,
    batch_size=29,
    validation_data=(X_test, Y_test),
)

I am using R2 as a custom metric; you can also create one.

這里我使用的是 r2,即 $R^2 = 1 - \mathrm{RSS}/\mathrm{TSS}$

# R2 recorded on the validation data per epoch (labelled 'Test_score' below).
plt.plot(history.history['val_r2_keras_custom'])
# R2 recorded on the training data per epoch (labelled 'Train_score' below).
plt.plot(history.history['r2_keras_custom'])
plt.legend(['Test_score','Train_score'])
# NOTE(review): plt.plot() with no arguments adds no curve; presumably
# plt.show() was intended — confirm.
plt.plot()

在此處輸入圖像描述

最終成績

我希望這會有所幫助,如有錯誤,歡迎其他人糾正我

如果這是整個數據集,那么它很小。 要考慮的一種選擇是研究交叉驗證,而不是將數據拆分為訓練和驗證(AKA 測試)。 交叉驗證是一種針對小型數據集的方法,其中所有數據都用於訓練和驗證,但仍可防止過度擬合。

您可以為每個 model 和交叉驗證執行超參數調整。

這個 class 可以幫助您做到這一點: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html

GridSearchCV 也與 Keras model 兼容。 為此,您可以查看: https://machinelearningmastery.com/grid-search-hyperparameters-deep-learning-models-python-keras/

就個人而言,訓練數據集中的少量記錄意味着訓練機器學習算法集合中的少量基分類器。 檢查您的代碼,我之前沒有使用過 RobustScaler,但我會在測試數據集上使用 transform,而不是 fit_transform。

回到您的代碼,看起來隨機森林的准確性最高。 通過調整一些超參數,包括估計器的數量(n_estimators)和 max_depth,可以得到更好的性能。 此外,正如其他答案/評論所推薦的那樣,此處需要對算法的超參數進行調優。

# -*- coding: utf-8 -*-
"""
Created on Wed Jan  6 20:50:44 2021

@author: AliHaidar
"""

import numpy as np
from sklearn.preprocessing import RobustScaler
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Ridge
from sklearn.linear_model import Lasso
from sklearn.linear_model import ElasticNet
from sklearn.ensemble import RandomForestRegressor,GradientBoostingRegressor,AdaBoostRegressor
from sklearn import metrics

from xgboost import XGBRegressor


# Linear Model
def GetLinearModel(X, y):
    """Return an ordinary least-squares regressor fitted on (X, y)."""
    estimator = LinearRegression()
    # fit() returns the estimator it was called on.
    return estimator.fit(X, y)

# Ridge Regression
def GetRidge(X, y):
    """Fit a ridge regressor (alpha=0.01) and return it.

    Args:
        X: Training feature matrix.
        y: Training targets.

    Returns:
        The fitted ``Ridge`` estimator.
    """
    model = Ridge(alpha=0.01)
    # Train on the data passed in, not on the module-level X_train/y_train,
    # so the parameters are actually honoured.
    model.fit(X, y)
    return model

# LASSO Regression
def GetLASSO(X, y):
    """Fit a LASSO regressor (alpha=0.01) and return it.

    Args:
        X: Training feature matrix.
        y: Training targets.

    Returns:
        The fitted ``Lasso`` estimator.
    """
    model = Lasso(alpha=0.01)
    # Train on the data passed in, not on the module-level X_train/y_train,
    # so the parameters are actually honoured.
    model.fit(X, y)
    return model

# ElasticNet Regression
def GetElasticNet(X, y):
    """Fit an elastic-net regressor (alpha=0.01) and return it.

    Args:
        X: Training feature matrix.
        y: Training targets.

    Returns:
        The fitted ``ElasticNet`` estimator.
    """
    model = ElasticNet(alpha=0.01)
    # Train on the data passed in, not on the module-level X_train/y_train,
    # so the parameters are actually honoured.
    model.fit(X, y)
    return model

# Random Forest
def GetRandomForest(X, y):
    """Train a 4-tree random forest (depth capped at 11, fixed seed) and return it."""
    forest = RandomForestRegressor(max_depth=11, random_state=0, n_estimators=4)
    # fit() returns the fitted estimator itself.
    return forest.fit(X, y)


# Train data: each row is 11 feature values followed by the credit score (last column)
train_set = np.array([\
[2, 5, 9, 28, 0, 0.153668, 500, 0, 0, 0.076923077, 0, 800],\
[3, 0, 0, 42, 2, 0.358913, 500, 0, 0, 0.230769231, 0, 900],\
[3, 0, 0, 12, 2, 0, 500, 0, 0, 0.076923077, 0, 500],\
[1, 0, 0, 6, 1, 0.340075, 457, 0, 0, 0.076923077, 0, 560],\
[1, 5, 0, 12, 3, 0.458358, 457, 0, 0, 0.153846154, 0, 500],\
[1, 3, 4, 32, 2, 0.460336, 457, 0, 0, 0.153846154, 0, 600],\
[3, 0, 0, 42, 4, 0.473414, 500, 0, 0, 0.230769231, 0, 700],\
[1, 3, 0, 16, 0, 0.332991, 500, 0, 0, 0.076923077, 0, 600],\
[1, 3, 19, 27, 0, 0.3477, 500, 0, 0, 0.076923077, 0, 580],\
[1, 5, 20, 74, 1, 0.52076, 500, 0, 0, 0.230769231, 0, 550],\
[6, 0, 0, 9, 3, 0, 500, 0, 0, 0.076923077, 0, 570],\
[1, 8, 47, 0, 0, 0.840656, 681, 0, 0, 0, 0, 50],\
[1, 0, 0, 8, 14, 0, 681, 0, 0, 0.076923077, 0, 400],\
[5, 6, 19, 7, 1, 0.251423, 500, 0, 1, 0.076923077, 1, 980],\
[1, 0, 0, 2, 2, 0.121852, 500, 1, 0, 0.076923077, 9, 780],\
[2, 0, 0, 4, 0, 0.37242, 500, 1, 0, 0.076923077, 0, 920],\
[3, 4, 5, 20, 0, 0.37682, 500, 1, 0, 0.076923077, 0, 700],\
[3, 8, 17, 20, 0, 0.449545, 500, 1, 0, 0.076923077, 0, 300],\
[3, 12, 30, 20, 0, 0.551193, 500, 1, 0, 0.076923077, 0, 30],\
[0, 1, 10, 8, 3, 0.044175, 500, 0, 0, 0.076923077, 0, 350],\
[1, 0, 0, 14, 3, 0.521714, 500, 0, 0, 0.153846154, 0, 650],\
[2, 4, 15, 0, 0, 0.985122, 500, 0, 0, 0, 0, 550],\
[2, 4, 34, 0, 0, 0.666666, 500, 0, 0, 0, 0, 600],\
[1, 16, 17, 10, 3, 0.299756, 330, 0, 0, 0.153846154, 0, 650],\
[2, 0, 0, 16, 1, 0, 500, 0, 0, 0.076923077, 0, 900],\
[2, 5, 31, 26, 0, 0.104847, 500, 0, 0, 0.076923077, 0, 850],\
[2, 6, 16, 34, 1, 0.172947, 500, 0, 0, 0.153846154, 0, 900],\
[1, 4, 0, 16, 6, 0.206403, 500, 0, 0, 0.153846154, 0, 630],\
[1, 8, 20, 12, 5, 0.495897, 500, 0, 0, 0.153846154, 0, 500],\
[1, 8, 46, 8, 6, 0.495897, 500, 0, 0, 0.153846154, 0, 250],\
[2, 0, 0, 4, 8, 0, 500, 0, 0, 0.076923077, 0, 550],\
[2, 6, 602, 0, 0, 0, 500, 0, 0, 0, 0, 20],\
[0, 12, 5, 21, 0, 0.158674, 645, 0, 0, 0.153846154, 0, 850],\
[0, 12, 20, 21, 0, 0.158674, 645, 0, 0, 0.153846154, 0, 700],\
[1, 0, 0, 33, 0, 0.041473, 645, 0, 0, 0.230769231, 0, 890],\
[1, 0, 0, 12, 2, 0.147325, 500, 0, 0, 0.076923077, 0, 780],\
[1, 8, 296, 0, 0, 2.891695, 521, 0, 0, 0, 0, 1],\
[1, 0, 0, 4, 0, 0.098953, 445, 0, 0, 0.076923077, 0, 600],\
[1, 0, 0, 4, 0, 0.143443, 500, 0, 0, 0.076923077, 0, 500],\
[0, 8, 20, 0, 0, 1.110002, 833, 0, 0, 0, 0, 100],\
[0, 0, 0, 8, 2, 0, 833, 0, 0, 0.076923077, 0, 300],\
[1, 4, 60, 20, 6, 0.78685, 833, 0, 0, 0.153846154, 0, 100],\
[1, 4, 112, 20, 6, 0.78685, 833, 0, 0, 0.153846154, 0, 1],\
[1, 0, 0, 21, 10, 0.305556, 500, 0, 0, 0.307692308, 0, 150],\
[1, 0, 0, 21, 10, 0.453743, 500, 0, 0, 0.307692308, 0, 300],\
[0, 0, 0, 8, 0, 0, 570, 0, 0, 0, 0, 500],\
[0, 10, 10, 8, 0, 0.325975, 570, 0, 0, 0.076923077, 0, 450],\
[1, 7, 16, 15, 1, 0.266311, 570, 0, 0, 0.076923077, 0, 450],\
[1, 1, 32, 30, 4, 0.134606, 570, 0, 0, 0.230769231, 0, 250],\
[1, 0, 0, 32, 5, 0.105576, 570, 0, 0, 0.230769231, 0, 430],\
[1, 4, 34, 32, 5, 0.519103, 500, 0, 0, 0.230769231, 0, 350],\
[1, 0, 0, 12, 1, 0.109559, 669, 0, 0, 0.076923077, 0, 600],\
[11, 4, 15, 2, 3, 0.235709, 500, 0, 1, 0, 2, 900],\
[11, 4, 15, 1, 6, 0.504134, 500, 0, 1, 0, 2, 534],\
[2, 0, 0, 15, 9, 0.075403, 500, 0, 0, 0.076923077, 0, 573],\
[10, 0, 0, 51, 11, 2.211951, 500, 0, 0, 0.307692308, 7, 547],\
[9, 0, 0, 28, 4, 0.328037, 500, 0, 0, 0.230769231, 0, 747],\
[9, 2, 0, 0, 0, 0.166666, 500, 0, 1, 0.076923077, 4, 448],\
[8, 0, 0, 4, 1, 0, 500, 0, 1, 0, 1, 719],\
[3, 4, 15, 8, 1, 0.150237, 500, 0, 1, 0, 0, 827],\
[7, 138, 35, 37, 1, 0.414154, 500, 0, 1, 0.076923077, 3, 950],\
[6, 19, 41, 84, 1, 0.41248, 500, 0, 0, 0.230769231, 0, 750],\
[1, 6, 10, 0, 0, 0.232647, 500, 0, 1, 0, 0, 700],\
[0, 10, 27, 0, 0, 0.411712, 4, 0, 0, 0, 0, 520],\
[3, 31, 45, 80, 0, 0.266299, 500, 0, 0, 0.153846154, 0, 750],\
[3, 24, 49, 2, 1, 0.981102, 500, 0, 0, 0.076923077, 0, 550],\
[1, 12, 31, 11, 1, 0.333551, 500, 0, 0, 0.153846154, 0, 500],\
[0, 18, 30, 13, 2, 0.602826, 406, 0, 0, 0.076923077, 0, 580],\
[2, 2, 31, 0, 0, 1, 500, 0, 0, 0, 0, 427],\
[1, 18, 40, 83, 1, 0.332792, 500, 0, 0, 0.307692308, 0, 485],\
[2, 14, 35, 9, 3, 0.39671, 500, 0, 1, 0.076923077, 3, 664],\
[2, 88, 32, 7, 2, 0.548066, 500, 0, 1, 0, 1, 90],\
[2, 26, 26, 32, 2, 0.415991, 500, 0, 0, 0.153846154, 0, 90],\
[1, 14, 30, 11, 1, 0.51743, 599, 0, 0, 0.153846154, 0, 300],\
[1, 15, 28, 26, 0, 0.4413, 500, 0, 0, 0.076923077, 0, 610],\
[1, 17, 50, 34, 1, 0.313789, 500, 0, 0, 0.230769231, 0, 450],\
[0, 4, 15, 0, 0, 0.535163, 500, 0, 0, 0, 0, 375],\
[0, 8, 23, 0, 0, 0.51242, 500, 0, 0, 0, 0, 550],\
[3, 6, 44, 2, 3, 0.268062, 500, 0, 1, 0, 2, 744],\
[6, 38, 51, 35, 0, 0.28396, 500, 0, 1, 0.076923077, 1, 980],\
[6, 5, 63, 6, 5, 0.566661, 500, 0, 0, 0.153846154, 0, 850],\
[6, 0, 0, 0, 0, 0.174852, 500, 0, 0, 0, 0, 800],\
[6, 4, 60, 6, 3, 0.517482, 500, 0, 0, 0.076923077, 0, 750],\
[5, 16, 52, 49, 4, 0.378441, 500, 0, 1, 0.153846154, 6, 720],\
[5, 26, 84, 103, 1, 0.472361, 500, 0, 0, 0.230769231, 0, 300],\
[1, 6, 34, 36, 1, 0.298553, 500, 0, 1, 0.153846154, 0, 628],\
[5, 6, 65, 34, 0, 0.301907, 500, 0, 0, 0.153846154, 0, 710],\
[3, 16, 177, 29, 10, 0.501831, 500, 1, 0, 0.153846154, 0, 40],\
[2, 5, 45, 0, 0, 0.351668, 500, 0, 0, 0, 0, 708],\
[2, 7, 57, 7, 4, 0.432374, 500, 0, 0, 0.153846154, 0, 753],\
[1, 1, 75, 36, 0, 0.154085, 500, 0, 0, 0.076923077, 0, 610],\
[1, 16, 63, 13, 2, 0.331244, 500, 0, 0, 0.076923077, 0, 620],\
[1, 3, 55, 9, 0, 0.377253, 500, 0, 0, 0.076923077, 0, 640],\
[1, 1, 75, 5, 5, 0.877696, 500, 0, 0, 0.076923077, 0, 480],\
[1, 0, 0, 8, 5, 0.208742, 500, 0, 0, 0.153846154, 0, 520],\
[1, 3, 55, 29, 0, 0.228812, 678, 0, 0, 0.153846154, 0, 547],\
[1, 0, 0, 2, 2, 0.090459, 553, 0, 0, 0.076923077, 0, 535],\
[0, 4, 29, 0, 0, 0.292161, 500, 0, 0, 0, 0, 594],\
[1, 3, 64, 18, 6, 0.602431, 500, 0, 0, 0.230769231, 0, 500],\
[6, 9, 40, 74, 0, 0.567179, 500, 0, 0, 0.076923077, 0, 910],\
[4, 10, 65, 14, 1, 0.423915, 500, 0, 1, 0, 1, 713],\
[2, 0, 0, 6, 1, 0.114637, 500, 0, 0, 0.076923077, 0, 650],\
[5, 18, 74, 34, 0, 0.489314, 500, 0, 0, 0.153846154, 0, 500],\
[0, 6, 43, 9, 15, 0.599918, 612, 0, 0, 0.153846154, 0, 100],\
[4, 25, 64, 135, 0, 0.472659, 500, 0, 0, 0.230769231, 0, 560],\
[6, 3, 94, 12, 10, 0.31713, 500, 0, 0, 0.230769231, 0, 580],\
[1, 4, 69, 18, 9, 0.412528, 500, 0, 0, 0.307692308, 0, 362],\
[2, 21, 58, 21, 0, 0.53184, 500, 0, 0, 0.153846154, 0, 370],\
[0, 0, 0, 21, 4, 0.033438, 500, 0, 0, 0.153846154, 0, 500],\
[0, 10, 53, 20, 0, 0.619595, 500, 0, 0, 0.076923077, 0, 200],\
[2, 15, 63, 28, 2, 0.593453, 500, 0, 0, 0.153846154, 0, 574],\
[3, 2, 84, 21, 1, 0.302636, 500, 0, 0, 0.153846154, 0, 790],\
[4, 19, 47, 28, 0, 0.256892, 500, 0, 0, 0.076923077, 0, 748],\
[1, 0, 0, 0, 0, 0.119599, 500, 0, 0, 0, 0, 517],\
[3, 10, 53, 22, 0, 0.419703, 500, 0, 0, 0.153846154, 0, 800],\
[4, 7, 66, 70, 1, 0.362268, 500, 0, 0, 0.230769231, 0, 550],\
[0, 16, 88, 18, 3, 0.597145, 16, 0, 0, 0.153846154, 0, 50],\
[5, 8, 38, 0, 0, 0.666666, 500, 0, 0, 0, 0, 667]])

# Test data: same column layout as train_set (11 feature values, then the credit score)
test_set = np.array([\
[2, 16, 87, 30, 0, 0.168057, 500, 0, 1, 0.153846154, 1, 760],\
[3, 5, 83, 6, 4, 0.273522, 500, 0, 0, 0.076923077, 0, 877],\
[1, 0, 0, 12, 0, 0.262797, 500, 0, 0, 0.153846154, 0, 596],\
[2, 15, 46, 28, 0, 0.495495, 500, 0, 0, 0.076923077, 0, 680],\
[1, 0, 0, 22, 9, 0.254813, 500, 0, 0, 0.230769231, 0, 450],\
[3, 19, 59, 12, 0, 0.437851, 500, 0, 0, 0.153846154, 0, 850],\
[4, 5, 28, 0, 0, 0.34559, 500, 0, 1, 0.076923077, 1, 800],\
[1, 5, 58, 0, 0, 0.385379, 500, 0, 0, 0, 0, 641],\
[1, 4, 65, 15, 1, 0.2945, 500, 0, 0, 0.153846154, 0, 644],\
[0, 0, 0, 9, 3, 0.421612, 500, 0, 0, 0.076923077, 0, 580],\
[3, 31, 83, 2, 2, 0.436883, 500, 0, 0, 0.076923077, 0, 410],\
[0, 0, 0, 18, 5, 0.044898, 377, 0, 0, 0.230769231, 0, 520],\
[0, 8, 49, 12, 3, 0.428529, 500, 0, 1, 0.076923077, 1, 370],\
[0, 22, 89, 2, 1, 0.819431, 500, 0, 0, 0.076923077, 0, 440],\
[3, 27, 63, 124, 0, 0.375306, 500, 0, 0, 0.076923077, 0, 880],\
[3, 20, 64, 18, 5, 0.439412, 500, 0, 1, 0.076923077, 3, 820],\
[1, 6, 34, 2, 12, 0.495654, 500, 0, 0, 0.076923077, 0, 653],\
[0, 14, 225, 0, 0, 1, 486, 0, 0, 0, 0, 1],\
[2, 8, 87, 32, 1, 0.829792, 500, 0, 0, 0.230769231, 0, 570],\
[2, 15, 46, 24, 4, 0.500442, 500, 0, 0, 0.153846154, 0, 568]])

# split datasets into independent and dependent variables
# (the last column is the target; every column before it is a feature)
X_train, y_train = train_set[:, :-1], train_set[:, -1]
X_test, y_test = test_set[:, :-1], test_set[:, -1]

# feature scaling
sc = RobustScaler()
# Fit the scaler on the training data only ...
X_train = sc.fit_transform(X_train)
# ... then apply that same fitted scaling to the test data. Calling
# fit_transform here would re-fit on the test set and scale it inconsistently
# with the data the models were trained on.
X_test = sc.transform(X_test)

# Train each candidate on the scaled training data and print its mean
# absolute error on the test set, one line per model.
for label, build in (
    ("Linear", GetLinearModel),
    ("Ridge", GetRidge),
    ("LASSO", GetLASSO),
    ("ElasticNet", GetElasticNet),
    ("Random Forest", GetRandomForest),
):
    reg = build(X_train, y_train)
    y_pred = reg.predict(X_test)
    mae = metrics.mean_absolute_error(y_test, y_pred)
    print("%15s: %10f" % (label, mae))


Output:

         Linear: 141.265089
          Ridge: 141.267797
          LASSO: 141.274700
     ElasticNet: 141.413544
  Random Forest:  90.776332

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM