簡體   English   中英

有沒有辦法在自定義 Tensorflow 模型中找到每個類的平均精度和召回率?

[英]Is there a way to find the average precision and recall of each class in the custom Tensorflow model?

我已經使用 TensorFlow 和 SSD MobileNet 訓練了一個模型。 我能夠找到模型的平均平均精度。 他們是一種找到模型中每個類的平均精度的方法。 我正在使用 TensorFlow 2.5 版本。 提前致謝

您可以按照下面的代碼使用 sklearn。 對於混淆矩陣和分類報告,您需要提供 y_predict 和 y_true。 訓練模型后,對測試集進行預測。 我假設您的代碼中有 y_true 作為類的標簽。 我將假設它們存在於一個名為 y_true 的列表中,並且與您對 model.predict 的輸入的順序相同。 我還將假設您有一個名為 classes 的列表,它們是按順序排列的類的名稱。 例如,如果貓是標簽 0,狗是標簽 1,那么 classes=[cats, dogs]

from sklearn.metrics import confusion_matrix, classification_report
preds=model.predict ---etc
ypredict=[]
for p in preds:
    index=np.argmax(p)
    y_predict.append(index)
y_true= np.array(y_true)        
y_predict=np.array(y_predict)    
# create a confusion matrix 
cm = confusion_matrix(y_true, y_predict ) 
# code below formats the confusion matrix plot       
length=len(classes)
if length<8:
    fig_width=8
    fig_height=8
 else:
    fig_width= int(length * .5)
    fig_height= int(length * .5)
plt.figure(figsize=(fig_width, fig_height))
sns.heatmap(cm, annot=True, vmin=0, fmt='g', cmap='Blues', cbar=False)       
plt.xticks(np.arange(length)+.5, classes, rotation= 90)
plt.yticks(np.arange(length)+.5, classes, rotation=0)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix")
plt.show()
clr = classification_report(y_true, y_pred, target_names=classes)
print("Classification Report:\n----------------------\n", clr)

以下是分類報告的示例

Classification Report:
----------------------
               precision    recall  f1-score   support

      Banana       1.00      1.00      1.00        15
       Bread       1.00      1.00      1.00        15
        Eggs       1.00      1.00      1.00        15
        Milk       1.00      1.00      1.00        15
       Mixed       1.00      1.00      1.00        12
      Potato       1.00      1.00      1.00        15
     Spinach       1.00      1.00      1.00        15
      Tomato       1.00      1.00      1.00        15

    accuracy                           1.00       117
   macro avg       1.00      1.00      1.00       117
weighted avg       1.00      1.00      1.00       117

我面臨同樣的問題。 找不到這個 y_pred 或 y_test 來自哪里。 SSD MobileNetV2 訓練期間沒有定義 y_pred 或 y_test。 如果有人有答案請分享

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM