简体   繁体   English

scikit learn Random Forest Classifier 概率阈值

[英]scikit learn Random Forest Classifier probability threshold

I'm using sklearn RandomForestClassifier for a prediction task.我正在使用sklearn RandomForestClassifier进行预测任务。

from sklearn.ensemble import RandomForestClassifier

model = RandomForestClassifier(n_estimators=300, n_jobs=-1)
model.fit(x_train,y_train)
model.predict_proba(x_test)

There are 171 classes to predict.有 171 个类要预测。 I want to predict only those classes, where predict_proba(class) is at least 90%.我只想预测那些predict_proba(class)至少为 90% 的类。 Everything below should be set to 0 .以下所有内容都应设置为0

For example, given the following:例如,给定以下内容:

     1   2   3   4   5   6   7
0  0.0 0.0 0.1 0.9 0.0 0.0 0.0
1  0.2 0.1 0.1 0.3 0.1 0.0 0.2
2  0.1 0.1 0.1 0.1 0.1 0.4 0.1
3  1.0 0.0 0.0 0.0 0.0 0.0 0.0

my expected output is:我预期的 output 是:

0   4
1   0
2   0   
3   1

You can use numpy.argwhere as follows:您可以使用numpy.arg ,如下所示:

from sklearn.ensemble import RandomForestClassifier
import numpy as np

model = RandomForestClassifier(n_estimators=300, n_jobs=-1)
model.fit(x_train,y_train)
preds = model.predict_proba(x_test)

#preds = np.array([[0.0, 0.0, 0.1, 0.9, 0.0, 0.0, 0.0],
#                  [ 0.2, 0.1, 0.1, 0.3, 0.1, 0.0, 0.2],
#                  [ 0.1 ,0.1, 0.1, 0.1, 0.1, 0.4, 0.1],
#                  [ 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])

r = np.zeros(preds.shape[0], dtype=int)
t = np.argwhere(preds>=0.9)

r[t[:,0]] = t[:,1]+1
r
array([4, 0, 0, 1])

You can use list comprehensions:您可以使用列表推导:

import numpy as np

# dummy predictions - 3 samples, 3 classes
pred = np.array([[0.1, 0.2, 0.7],
                 [0.95, 0.02, 0.03],
                 [0.08, 0.02, 0.9]])

# first, keep only entries >= 0.9:
out_temp = np.array([[x[i] if x[i] >= 0.9 else 0 for i in range(len(x))] for x in pred])
out_temp
# result:
array([[0.  , 0.  , 0.  ],
       [0.95, 0.  , 0.  ],
       [0.  , 0.  , 0.9 ]])

out = [0 if not x.any() else x.argmax()+1 for x in out_temp]
out
# result:
[0, 1, 3]

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM