NameError: name 'plot_confusion_matrix' is not defined
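I'm trying to build a classification model with VGG16, but at the end of the project I get an error when plotting the confusion matrix. The code is below.
The imported packages and modules are: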
import os
import keras
import numpy as np
import tensorflow as tf
from keras.models import Model, Sequential
import matplotlib.pyplot as plt
from keras.optimizers import Adam
from keras.applications import MobileNet
from sklearn.metrics import confusion_matrix
from keras.layers import Dense, Activation
from keras.metrics import categorical_crossentropy
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import ImageDataGenerator
from keras.applications.mobilenet import preprocess_input
from tensorflow.keras.preprocessing import image_dataset_from_directory
Note: for brevity I've skipped the dataset part (train_batches, valid_batches and test_batches come from it).
Load VGG16:
vgg16_model = keras.applications.vgg16.VGG16()
vgg16_model.summary()
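Now, define the model:
model = Sequential()
# Copy every VGG16 layer except the last one (the 1000-class ImageNet softmax)
for layer in vgg16_model.layers[:-1]:
    model.add(layer)
# Freeze the copied layers so only the new output layer is trained
for layer in model.layers:
    layer.trainable = False
model.add(Dense(2, activation='softmax'))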
Compile the model:
model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])
Train the model:
model.fit(train_batches, steps_per_epoch=4, validation_data=valid_batches, validation_steps=4, epochs=10, verbose=2)
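Now for the confusion matrix:
test_imgs, test_labels = next(test_batches)
plots(test_imgs, titles=test_labels)  # plots() is not imported above; it must be a helper defined elsewhere in the notebook
# Predict on the batch just drawn so predictions line up with test_labels
predictions = model.predict(test_imgs, verbose=0)
# Turn one-hot labels and softmax outputs into class indices (0 = first class folder)
test_labels = np.argmax(test_labels, axis=1)
cm = confusion_matrix(test_labels, np.argmax(predictions, axis=1))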
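A small sketch on made-up arrays (four hypothetical test images with one-hot labels and two-way softmax outputs) shows why the label encoding matters. With `[:, 0]`, a value of 1 means "the column-0 class", so index 0 of the matrix holds the *other* class; with `argmax`, index i is the class in column i. Assuming column 0 is diseaseAffectedEggplant (flow_from_directory assigns class indices in alphabetical order of the class folders), only the argmax version matches the order of cm_plot_labels.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Made-up data; assumed column order: 0 = diseaseAffectedEggplant, 1 = freshEggplant
test_labels = np.array([[1, 0], [0, 1], [1, 0], [0, 1]], dtype=np.float32)
predictions = np.array([[0.9, 0.1], [0.2, 0.8], [0.4, 0.6], [0.3, 0.7]], dtype=np.float32)

# Column-0 approach: label 1 means diseaseAffectedEggplant,
# so row/column 0 of the matrix is freshEggplant.
cm_col0 = confusion_matrix(test_labels[:, 0], np.round(predictions[:, 0]))

# Argmax approach: label i means "class in column i",
# so row/column 0 of the matrix is diseaseAffectedEggplant.
cm_argmax = confusion_matrix(np.argmax(test_labels, axis=1), np.argmax(predictions, axis=1))

print(cm_col0)    # [[2 0]
                  #  [1 1]]
print(cm_argmax)  # [[1 1]
                  #  [0 2]]
```

Both matrices hold the same counts, flipped along both axes, so the tick labels are only correct for the one whose index order matches them.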
This is where I get the error; see the code below:
cm_plot_labels = ['diseaseAffectedEggplant','freshEggplant']
plot_confusion_matrix(cm, cm_plot_labels, title="Confusion Matrix")  # this line raises the error
The error is as follows:
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
<ipython-input-28-43b96d543746> in <module>()
1 cm_plot_labels = ['diseaseAffectedEggplant','freshEggplant']
----> 2 plot_confusion_matrix(cm, cm_plot_labels, title="Confusion Matrix")
NameError: name 'plot_confusion_matrix' is not defined
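The call in the question passes a precomputed matrix and a list of class names, which matches a small user-defined plotting helper rather than any function imported above. A hypothetical sketch of such a helper, written with Matplotlib and NumPy only; the name and signature simply mirror the question's call and are not part of scikit-learn or Keras.

```python
import itertools

import matplotlib.pyplot as plt
import numpy as np


def plot_confusion_matrix(cm, classes, normalize=False, title="Confusion matrix",
                          cmap=plt.cm.Blues):
    """Hypothetical helper: draw a precomputed confusion matrix `cm`,
    using `classes` as tick labels for both axes."""
    cm = np.asarray(cm)
    if normalize:
        # Divide each row by its total so every row sums to 1
        cm = cm.astype(float) / cm.sum(axis=1, keepdims=True)

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
    fig.colorbar(im, ax=ax)
    ax.set_title(title)
    ticks = np.arange(len(classes))
    ax.set_xticks(ticks)
    ax.set_xticklabels(classes, rotation=45)
    ax.set_yticks(ticks)
    ax.set_yticklabels(classes)

    # Write each value in its cell; use white text on dark cells
    fmt = ".2f" if normalize else "d"
    thresh = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, format(cm[i, j], fmt), ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")

    ax.set_ylabel("True label")
    ax.set_xlabel("Predicted label")
    fig.tight_layout()
    return fig, ax


# Usage with made-up counts
cm_plot_labels = ['diseaseAffectedEggplant', 'freshEggplant']
fig, ax = plot_confusion_matrix(np.array([[1, 1], [0, 2]]), cm_plot_labels,
                                title="Confusion Matrix")
```

Once a helper like this is defined in the notebook before the failing cell, the original call runs as written; in a notebook the figure is displayed automatically, elsewhere call plt.show().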
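You need to import plot_confusion_matrix from the sklearn.metrics module:
from sklearn.metrics import plot_confusion_matrix
See the documentation. Note, however, that scikit-learn's version is called as plot_confusion_matrix(estimator, X, y_true, ...): it runs a fitted scikit-learn classifier on X and builds the matrix itself, so it does not accept a precomputed cm plus a list of labels as in the call above. It was also deprecated in scikit-learn 1.0 and removed in 1.2; for a precomputed matrix, use sklearn.metrics.ConfusionMatrixDisplay instead.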
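If you already have a precomputed matrix, as in the question, a minimal sketch with sklearn.metrics.ConfusionMatrixDisplay (available since scikit-learn 0.22) looks like this; the counts are made up for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay

cm = np.array([[1, 1], [0, 2]])  # made-up counts; rows = true class, columns = predicted
cm_plot_labels = ['diseaseAffectedEggplant', 'freshEggplant']

# Wrap the existing matrix and draw it with the class names as tick labels
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=cm_plot_labels)
disp.plot(cmap=plt.cm.Blues)
disp.ax_.set_title("Confusion Matrix")
```

With scikit-learn 1.0 or later, ConfusionMatrixDisplay.from_predictions(y_true, y_pred, display_labels=cm_plot_labels) computes and plots the matrix in one call from arrays of class indices.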