繁体   English   中英

有没有办法在自定义 Tensorflow 模型中找到每个类的平均精度和召回率?

[英]Is there a way to find the average precision and recall of each class in the custom Tensorflow model?

我已经使用 TensorFlow 和 SSD MobileNet 训练了一个模型。 我能够找到模型的平均平均精度。 他们是一种找到模型中每个类的平均精度的方法。 我正在使用 TensorFlow 2.5 版本。 提前致谢

您可以按照下面的代码使用 sklearn。 对于混淆矩阵和分类报告,您需要提供 y_predict 和 y_true。 训练模型后,对测试集进行预测。 我假设您的代码中有 y_true 作为类的标签。 我将假设它们存在于一个名为 y_true 的列表中,并且与您对 model.predict 的输入的顺序相同。 我还将假设您有一个名为 classes 的列表,它们是按顺序排列的类的名称。 例如,如果猫是标签 0,狗是标签 1,那么 classes=[cats, dogs]

from sklearn.metrics import confusion_matrix, classification_report
preds=model.predict ---etc
ypredict=[]
for p in preds:
    index=np.argmax(p)
    y_predict.append(index)
y_true= np.array(y_true)        
y_predict=np.array(y_predict)    
# create a confusion matrix 
cm = confusion_matrix(y_true, y_predict ) 
# code below formats the confusion matrix plot       
length=len(classes)
if length<8:
    fig_width=8
    fig_height=8
 else:
    fig_width= int(length * .5)
    fig_height= int(length * .5)
plt.figure(figsize=(fig_width, fig_height))
sns.heatmap(cm, annot=True, vmin=0, fmt='g', cmap='Blues', cbar=False)       
plt.xticks(np.arange(length)+.5, classes, rotation= 90)
plt.yticks(np.arange(length)+.5, classes, rotation=0)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix")
plt.show()
clr = classification_report(y_true, y_pred, target_names=classes)
print("Classification Report:\n----------------------\n", clr)

以下是分类报告的示例

Classification Report:
----------------------
               precision    recall  f1-score   support

      Banana       1.00      1.00      1.00        15
       Bread       1.00      1.00      1.00        15
        Eggs       1.00      1.00      1.00        15
        Milk       1.00      1.00      1.00        15
       Mixed       1.00      1.00      1.00        12
      Potato       1.00      1.00      1.00        15
     Spinach       1.00      1.00      1.00        15
      Tomato       1.00      1.00      1.00        15

    accuracy                           1.00       117
   macro avg       1.00      1.00      1.00       117
weighted avg       1.00      1.00      1.00       117

我面临同样的问题。 找不到这个 y_pred 或 y_test 来自哪里。 SSD MobileNetV2 训练期间没有定义 y_pred 或 y_test。 如果有人有答案请分享

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM