Python Scikit - bad input shape when calling sklearn.metrics.precision_recall_curve
I'm trying to build a PRC (precision-recall curve) for a CatBoostClassifier.
But when I call sklearn.metrics.precision_recall_curve(y_test, y_score), I get ValueError: bad input shape (11912, 2).
What is wrong with my current approach, and what do I need to fix here to provide the correct shape?
import sklearn
from sklearn import metrics
y_score = model.predict_proba(X_test)
prc_auc = sklearn.metrics.precision_recall_curve(y_test, y_score)
Here is how I build the model:
from catboost import CatBoostClassifier

model = CatBoostClassifier(
    iterations=50,
    random_seed=63,
    learning_rate=0.15,
    custom_loss=['Accuracy', 'Precision', 'Recall', 'AUC']
)
model.fit(
    X_train, y_train,
    cat_features=cat_features,
    eval_set=(X_test, y_test),
    verbose=10,
    plot=True
);
The trivial answer is that CatBoostClassifier.predict_proba returns a 2d array, while sklearn.metrics.precision_recall_curve requires a 1d array (or a 2d array with one column, whichever).
The documentation for CatBoostClassifier says that predict_proba() returns numpy.array, and provides no other information about this method. So I hate the documentation for this package now.
Walking through some poorly-commented code gets me to:
if prediction_type == 'Probability':
    predictions = np.transpose([1 - predictions, predictions])
return predictions
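To see what that transpose expression produces, here is a small check on a toy array. The values in `predictions` are made up for illustration and do not come from the model in the question.

```python
import numpy as np

# Hypothetical positive-class probabilities for three samples.
predictions = np.array([0.1, 0.5, 0.75])

# Same expression as in the snippet above: build two rows, then transpose
# so that each sample gets one row with two columns.
proba = np.transpose([1 - predictions, predictions])

print(proba.shape)  # one row per sample, two columns
print(proba)
```

Each row sums to 1, and the second column is the original `predictions` array.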
I'm guessing that column 0 is the probability of class 0, and column 1 is the probability of class 1. So pick whichever of those your test labels align with and use that column only.
prc_auc = sklearn.metrics.precision_recall_curve(y_test, y_score[:, 1])
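For anyone who wants to try the fix end to end, here is a minimal sketch. It makes several assumptions: the data is synthetic, and a scikit-learn LogisticRegression stands in for the CatBoost model so the example runs without catboost installed. Note that precision_recall_curve returns three arrays rather than a single AUC value, so the area is computed separately with sklearn.metrics.auc.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, precision_recall_curve
from sklearn.model_selection import train_test_split

# Synthetic binary data (assumption): stands in for the asker's dataset.
X, y = make_classification(n_samples=500, n_features=8, random_state=63)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=63
)

# Stand-in classifier (assumption): used only because it also exposes
# predict_proba with one column per class.
model = LogisticRegression(max_iter=1000).fit(X_train, y_train)

y_score = model.predict_proba(X_test)  # 2d: (n_samples, 2)

# Column order follows model.classes_, so look up the positive label (1)
# instead of hard-coding the column index.
pos_col = list(model.classes_).index(1)
pos_score = y_score[:, pos_col]        # 1d: (n_samples,)

# precision_recall_curve returns three arrays, not a scalar.
precision, recall, thresholds = precision_recall_curve(y_test, pos_score)
prc_auc = auc(recall, precision)       # area under the PR curve

print(y_score.shape, pos_score.shape, prc_auc)
```

Passing the full `y_score` instead of `pos_score` is what triggers the shape error described in the question.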