
How to get 10 individual confusion matrices from 10-fold cross-validation with sklearn

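I'm new to machine learning, so this is my first time using the sklearn package. In this classification problem I want to get a confusion matrix for each fold, but I only get a single one. Here is what I have done so far (I left out the preprocessing part):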

from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_predict

target = df["class"]
features = df.drop("class", axis=1)
split_df = round(0.8 * len(df))

features = features.sample(frac=1, random_state=0)
target = target.sample(frac=1, random_state=0)

trainFeatures, trainClassLabels = features.iloc[:split_df], target.iloc[:split_df]
testFeatures, testClassLabels = features.iloc[split_df:], target.iloc[split_df:]

tree = DecisionTreeClassifier(random_state=0)
tree.fit(X=trainFeatures, y=trainClassLabels)

y_pred = cross_val_predict(tree, X=features, y=target, cv=10)

conf_matrix = confusion_matrix(target, y_pred)
print("Confusion matrix:\n", conf_matrix)

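cross_val_predict returns a single array with one out-of-fold prediction per sample, so one call to confusion_matrix sees all folds at once. To get one matrix per fold you need to know which rows were in each fold. If you create a KFold object and pass it as cv instead of specifying cv=10, you can reuse the exact same splits afterwards. For example: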

from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import KFold, cross_val_predict
from sklearn.datasets import make_classification

features, target = make_classification(random_state=0)

tree = DecisionTreeClassifier(random_state=0)
kf = KFold(n_splits=10, random_state=99, shuffle=True)

y_pred = cross_val_predict(tree, X=features, y=target, cv=kf)

conf_matrix = confusion_matrix(target, y_pred)
print("Confusion matrix:\n", conf_matrix)

Confusion matrix:
 [[41  9]
 [ 6 44]]

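Then we can build a confusion matrix for each fold. Because kf uses shuffle=True with a fixed integer random_state, calling kf.split again yields the same folds that cross_val_predict used: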

lst = []
# Re-create each fold's test indices and compare true labels with the out-of-fold predictions
for train_index, test_index in kf.split(features):
    lst.append(confusion_matrix(target[test_index], y_pred[test_index]))
    
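The loop above slices the out-of-fold predictions returned by cross_val_predict. A more explicit sketch that gives the same result fits a fresh copy of the tree (via `sklearn.base.clone`) on each training split and predicts its test split directly. It assumes the same `make_classification` data, tree and `KFold` settings as this answer; the names `fold_matrices` and `fold_pred` are just illustrative:

```python
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import KFold
from sklearn.tree import DecisionTreeClassifier

# Same synthetic data, model and splitter settings as in the answer above
features, target = make_classification(random_state=0)
tree = DecisionTreeClassifier(random_state=0)
kf = KFold(n_splits=10, random_state=99, shuffle=True)

fold_matrices = []
for train_index, test_index in kf.split(features):
    # Fit an independent, unfitted copy of the tree on this fold's training rows only
    model = clone(tree).fit(features[train_index], target[train_index])
    # Predict only the held-out rows of this fold
    fold_pred = model.predict(features[test_index])
    fold_matrices.append(confusion_matrix(target[test_index], fold_pred))

for i, cm in enumerate(fold_matrices, start=1):
    print(f"Fold {i}:\n{cm}")
```

cross_val_predict also fits a clone of the estimator on each training split, and both the tree and kf are seeded, so this should reproduce the same per-fold matrices as the loop above.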

It looks like this:

[array([[4, 0],
        [0, 6]]),
 array([[4, 3],
        [1, 2]]),
 array([[2, 0],
        [2, 6]]),
 array([[5, 1],
        [0, 4]]),
 array([[4, 1],
        [1, 4]]),
 array([[2, 2],
        [0, 6]]),
 array([[4, 0],
        [0, 6]]),
 array([[4, 1],
        [1, 4]]),
 array([[4, 1],
        [1, 4]]),
 array([[8, 0],
        [0, 2]])]


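Notice: the technical posts on this site are published under the CC BY-SA 4.0 license. If you republish them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.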

Guangdong ICP Registration No. 18138465  © 2020-2024 STACKOOM.COM