簡體   English   中英

Keras + Tensorflow中的混淆矩陣

[英]Confusion Matrix in Keras+Tensorflow

Q1

我已經訓練了CNN模型,並將其另存為model.h5 我正在嘗試檢測3個物體。 說“貓”,“狗”和“其他”。 我的測試集有300張圖像,每個類別有100張圖像。 前100個是“貓”,第二個100是“狗”,第3個100是“其他”。 我正在使用flow_from_directoryImageDataGeneratorflow_from_directory 這是示例代碼:

test_datagen = ImageDataGenerator(rescale=1./255)
test_generator = test_datagen.flow_from_directory(
        test_dir,
        target_size=(150, 150),
        batch_size=20,
        class_mode='sparse',
        shuffle=False)

現在使用

from sklearn.metrics import confusion_matrix

cnf_matrix = confusion_matrix(y_test, y_pred)

我需要y_testy_pred 我可以使用以下代碼獲取y_pred

probabilities = model.predict_generator(test_generator)
y_pred = np.argmax(probabilities, axis=1)
print (y_pred)

[0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 1 0 0 0 0 0 0 1 0 0 0
 0 0 0 0 1 0 0 0 0 1 2 0 2 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 1 1
 0 2 0 0 0 0 1 0 0 0 0 0 0 1 0 2 0 1 0 0 1 0 0 1 0 0 1 1 1 1 1 1 1 1 1 1 2
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1
 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 2 2 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 2
 1 1 1 1 1 2 1 1 1 1 1 2 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 1 2 2 2 1 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2]

這基本上是將對象預測為0,1和2。現在我知道前100個對象(貓)為0,第二個100對象(狗)為1,第3個100對象(其他)為2。是否手動創建列表?使用numpy ,其中前100點為0,第二個100點為1,第3個100點為2以得到y_test 是否有任何Keras類可以做到這一點(創建y_test )?

Q2

如何查看錯誤檢測的對象。 如果您查看print(y_pred) ,則第三個點是1,這是錯誤預測的。 如何在不手動進入“ test_dir”文件夾的情況下看到該圖像?

由於您沒有使用任何增強和shuffle=False ,因此可以簡單地從生成器獲取圖像:

imgBatch = next(test_generator)
    #it may be interesting to create the generator again if 
    #you're not sure it has output exactly all images before

使用繪圖庫(例如Pillow(PIL)或MatplotLib)在imgBatch中繪制每個圖像。

要僅繪制所需的圖像, y_testy_pred進行比較:

compare = y_test == y_pred

position = 0
while position < len(y_test):
    imgBatch = next(test_generator)
    batch = imgBatch.shape[0]

    for i in range(position,position+batch):
        if compare[i] == False:
            plot(imgBatch[i-position])

    position += batch

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM