简体   繁体   English

从混淆矩阵中显示错误分类的数字

[英]Display misclassified digits from confusion matrix

I wrote a function to find the confusion matrix of my model:我写了一个 function 来找到我的 model 的混淆矩阵:

NN_model = KNeighborsClassifier(n_neighbors=1)
NN_model.fit(mini_train_data, mini_train_labels)
# Create the confusion matrix for the dev data
confusion = confusion_matrix(dev_labels, NN_model.predict(dev_data))
print(confusion)

我得到以下输出

But I am having trouble in displaying images of above 5 digits that are often confused with others.但是我在显示经常与其他数字混淆的 5 位以上的图像时遇到了麻烦。 But when I try the following code, I am not getting the expected results.但是当我尝试下面的代码时,我没有得到预期的结果。

index = 0
misclassifiedIndexes = []
for label, predict in zip(dev_labels, predictions):
     if label != predict: 
        misclassifiedIndexes.append(index)
        index +=1

plt.figure(figsize=(20,4))
for plotIndex, badIndex in enumerate(misclassifiedIndexes[0:5]):
    plt.subplot(1, 5, plotIndex + 1)
    plt.imshow(np.reshape(dev_data[badIndex], (28,28)), cmap=plt.cm.gray)
    plt.title('Predict: {}, Actual: {}'.format(predictions[badIndex], dev_labels[badIndex]), fontsize = 15)

这是我得到的,但理想情况下我想显示错误分类的数字

Can you please take a look to see what is going wrong with my code?你能看看我的代码有什么问题吗? Thank you!谢谢!

As such I cannot issue with your code.因此,我不能对您的代码提出问题。 Hence, I'm providing here a reproducible code.因此,我在这里提供了一个可重现的代码。

You can use the np.where on the Boolean comparison between predictions and actuals.您可以使用np.where上的 np.where 比较预测值和实际值。

Try this example:试试这个例子:

from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

X, y = load_digits(return_X_y=True)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4)

NN_model = KNeighborsClassifier(n_neighbors=1)
NN_model.fit(X_train, y_train)
# Create the confusion matrix for the dev data
from sklearn.metrics import confusion_matrix
predictions = NN_model.predict(X_test)
confusion = confusion_matrix(y_test, predictions)

import matplotlib.pyplot as plt
misclassifiedIndexes = np.where(y_test!=predictions)[0]


fig, ax = plt.subplots(4, 3,figsize=(15,8))
ax = ax.ravel()
for i, badIndex in enumerate(misclassifiedIndexes):
    ax[i].imshow(np.reshape(X_test[badIndex], (8, 8)), cmap=plt.cm.gray)
    ax[i].set_title(f'Predict: {predictions[badIndex]}, '
                    f'Actual: {y_test[badIndex]}', fontsize = 10)
    ax[i].set(frame_on=False)
    ax[i].axis('off')
plt.box(False)
plt.axis('off')

在此处输入图像描述

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM