簡體   English   中英

Scikit-Learn:GridSearchCV 的自定義損失函數

[英]Scikit-Learn: Custom Loss Function for GridSearchCV

我正在參加 Kaggle 比賽( https://www.kaggle.com/c/house-prices-advanced-regression-techniques#evaluation ),它指出我的模型將通過以下方式評估:

Submissions are evaluated on Root-Mean-Squared-Error (RMSE) between the logarithm of the predicted value and the logarithm of the observed sales price. (Taking logs means that errors in predicting expensive houses and cheap houses will affect the result equally.)

我在文檔中找不到這個(它基本上是RMSE(log(truth), log(prediction) ),所以我開始編寫一個自定義記分器:

def custom_loss(truth, preds):
    truth_logs = np.log(truth)
    print(truth_logs)
    preds_logs = np.log(preds)
    numerator = np.sum(np.square(truth_logs - preds_logs))
    return np.sum(np.sqrt(numerator / len(truth)))

custom_scorer = make_scorer(custom_loss, greater_is_better=False)

兩個問題:

1) 我的自定義損失函數是否應該返回一個 numpy 分數數組(每個(真實,預測)對一個?還是應該是這些(真實,預測)對的總損失,返回一個數字?

我查看了文檔,但它們並不是非常有用:我的自定義損失函數應該返回什么。

2)當我跑步時:

xgb_model = xgb.XGBRegressor()
params = {"max_depth": [3, 4], "learning_rate": [0.05],
         "n_estimators": [1000, 2000], "n_jobs": [8], "subsample": [0.8], "random_state": [42]}
grid_search_cv = GridSearchCV(xgb_model, params, scoring=custom_scorer,
                             n_jobs=8, cv=KFold(n_splits=10, shuffle=True, random_state=42), verbose=2)

grid_search_cv.fit(X, y)

grid_search_cv.best_score_

我回來了:

-0.12137097567803554

這是非常令人驚訝的。 鑒於我的損失函數采用RMSE(log(truth) - log(prediction)) ,我不應該有一個負的best_score_

知道為什么它是負面的嗎?

謝謝!

1)您應該返回一個數字作為損失,而不是數組。 GridSearchCV 將根據此計分器的結果對參數進行排序。

順便說一下,您可以使用mean_squared_log_error而不是定義自定義指標,它mean_squared_log_error您的需求。

2)為什么它返回負數? - 沒有你的實際數據和完整的代碼,我們不能說。

你應該小心符號。

這里有兩個級別的優化:

  1. XGBRegressor擬合數據時優化的損失函數。
  2. 在網格搜索過程中優化的評分函數。

我更喜歡調用第二個評分函數而不是損失函數,因為損失函數通常是指在模型擬合過程中進行優化的術語。 但是,您的自定義函數僅指定 2. 而保持 1. 不變。 如果您想更改XGBRegressor的損失函數,請參見此處 大多數回歸模型都有幾個標准供您選擇,例如mean_square_errormean_absolute_error

請注意,目前不支持傳遞自定義損失函數(請參閱此處此處的原因)。

如果greater_is_better 為False,則make_scorer 函數符號翻轉

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM