
How to set a threshold for a scikit-learn random forest model

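After looking at precision_recall_curve, suppose I want to set threshold = 0.4. How do I apply 0.4 in my random forest model (binary classification), so that any probability < 0.4 is labeled 0 and any probability >= 0.4 is labeled 1?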

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

random_forest = RandomForestClassifier(n_estimators=100, oob_score=True, random_state=12)
random_forest.fit(X_train, y_train)

predicted = random_forest.predict(X_test)
accuracy = accuracy_score(y_test, predicted)

Documentation: precision-recall
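For context, here is a minimal sketch of how candidate thresholds can be read off precision_recall_curve. The labels and scores below are hypothetical values chosen for illustration; in practice the scores would be the positive-class column of predict_proba.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

# Hypothetical labels and positive-class scores, used only for illustration.
y_true = np.array([0, 0, 1, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8])

precision, recall, thresholds = precision_recall_curve(y_true, scores)

# precision and recall have one more entry than thresholds: the final
# (precision=1, recall=0) point has no corresponding threshold.
for p, r, t in zip(precision[:-1], recall[:-1], thresholds):
    print(f"threshold={t:.2f}  precision={p:.2f}  recall={r:.2f}")
```

Each printed row is one candidate cut-off; a value such as 0.4 would be picked by weighing the precision and recall it gives.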

Assuming you are doing binary classification, this is straightforward:

threshold = 0.4

# predict_proba returns one column per class; column 1 is the positive class
predicted_proba = random_forest.predict_proba(X_test)
predicted = (predicted_proba[:, 1] >= threshold).astype(int)

accuracy = accuracy_score(y_test, predicted)
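The snippet above relies on X_train, X_test, y_train and y_test from the question. A self-contained sketch of the same idea follows, assuming synthetic data from make_classification as a stand-in for the asker's dataset. It also looks up the positive-class column through classes_ rather than hard-coding index 1, which is an optional safeguard and not part of the original answer.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Synthetic binary data stands in for the asker's dataset (an assumption).
X, y = make_classification(n_samples=500, n_features=10, random_state=12)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=12
)

random_forest = RandomForestClassifier(n_estimators=100, random_state=12)
random_forest.fit(X_train, y_train)

threshold = 0.4

# Columns of predict_proba follow random_forest.classes_, so find the
# column that belongs to the positive label (1).
positive_col = list(random_forest.classes_).index(1)
proba_positive = random_forest.predict_proba(X_test)[:, positive_col]

# 1-D array of 0/1 labels: 1 where P(class 1) >= threshold, else 0.
predicted = (proba_positive >= threshold).astype(int)

# For comparison: predict() picks the class with the highest probability.
default_predicted = random_forest.predict(X_test)

print("accuracy at threshold 0.4:", accuracy_score(y_test, predicted))
print("accuracy with predict():  ", accuracy_score(y_test, default_predicted))
```

Lowering the threshold below 0.5 can only add positive predictions relative to predict(), so recall tends to rise while precision may fall.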

random_forest = RandomForestClassifier(n_estimators=100)
random_forest.fit(X_train, y_train)

threshold = 0.4

predicted = random_forest.predict_proba(X_test)
predicted[:,0] = (predicted[:,0] < threshold).astype('int')
predicted[:,1] = (predicted[:,1] >= threshold).astype('int')


accuracy = accuracy_score(y_test, predicted)
print(round(accuracy, 4) * 100, "%")

This raises an error at the final accuracy step: "ValueError: Can't handle mix of binary and multilabel-indicator"

sklearn.metrics.accuracy_score expects 1-D arrays, but your prediction array is 2-D (one column per class), which causes the error.
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html
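A small sketch of the difference, using a hypothetical predict_proba output for four samples in place of a fitted model. Thresholding both columns in place keeps the array 2-D and reproduces the ValueError, while slicing out the positive-class column gives the 1-D array that accuracy_score expects.

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Hypothetical predict_proba output: one row per sample, one column per class.
proba = np.array([[0.70, 0.30],
                  [0.55, 0.45],
                  [0.20, 0.80],
                  [0.65, 0.35]])
y_true = np.array([0, 1, 1, 0])
threshold = 0.4

# Thresholding both columns keeps the 2-D shape, which accuracy_score
# cannot compare against the 1-D binary y_true.
predicted_2d = proba.copy()
predicted_2d[:, 0] = (predicted_2d[:, 0] < threshold).astype(int)
predicted_2d[:, 1] = (predicted_2d[:, 1] >= threshold).astype(int)
try:
    accuracy_score(y_true, predicted_2d)
    raised = False
except ValueError as err:
    raised = True
    print("ValueError:", err)

# Keeping only the positive-class column yields a 1-D label array.
predicted = (proba[:, 1] >= threshold).astype(int)
accuracy = accuracy_score(y_true, predicted)
print(predicted.shape, accuracy)
```

The exact wording of the error message may differ between scikit-learn versions.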

