
Is there a way to find the average precision and recall of each class in the custom Tensorflow model?

I have trained a model using TensorFlow and SSD MobileNet. I was able to find the mean average precision of the model. Is there a way to find the average precision of each class in the model? I am using TensorFlow 2.5. Thanks in advance.

You can use sklearn, per the code below. Both the confusion matrix and the classification report need y_true and y_predict. After you train your model, run predictions on the test set. I assume you have the true class labels somewhere in your code; I will assume they are in a list called y_true, in the SAME order as the inputs to model.predict. I will also assume you have a list called classes holding the class names in label order. For example, if cats is label 0 and dogs is label 1, then classes=['cats', 'dogs'].

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

preds = model.predict(...)  # arguments elided in the original; pass your test inputs here
# take the index of the highest score as the predicted class of each sample
y_predict = []
for p in preds:
    index = np.argmax(p)
    y_predict.append(index)
y_true = np.array(y_true)
y_predict = np.array(y_predict)
# create a confusion matrix (rows are actual classes, columns are predicted classes)
cm = confusion_matrix(y_true, y_predict)
# code below formats the confusion matrix plot
length = len(classes)
if length < 8:
    fig_width = 8
    fig_height = 8
else:
    fig_width = int(length * .5)
    fig_height = int(length * .5)
plt.figure(figsize=(fig_width, fig_height))
sns.heatmap(cm, annot=True, vmin=0, fmt='g', cmap='Blues', cbar=False)
plt.xticks(np.arange(length) + .5, classes, rotation=90)
plt.yticks(np.arange(length) + .5, classes, rotation=0)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix")
plt.show()
# per-class precision, recall and f1-score
clr = classification_report(y_true, y_predict, target_names=classes)
print("Classification Report:\n----------------------\n", clr)
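
As a side note, the confusion matrix already contains the per-class precision and recall: the diagonal holds the correct predictions, column sums count how often each class was predicted, and row sums count how often each class actually occurs. A minimal sketch with made-up toy labels (not data from the question):

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Toy labels, purely illustrative
y_true = np.array([0, 0, 1, 1, 1, 2])
y_predict = np.array([0, 1, 1, 1, 2, 2])

cm = confusion_matrix(y_true, y_predict, labels=[0, 1, 2])
true_pos = np.diag(cm).astype(float)   # correct predictions per class
predicted_per_class = cm.sum(axis=0)   # column sums: times each class was predicted
actual_per_class = cm.sum(axis=1)      # row sums: times each class actually occurs

# `where` leaves 0.0 for classes that were never predicted or never present
precision = np.divide(true_pos, predicted_per_class,
                      out=np.zeros_like(true_pos), where=predicted_per_class > 0)
recall = np.divide(true_pos, actual_per_class,
                   out=np.zeros_like(true_pos), where=actual_per_class > 0)

for label, (p, r) in enumerate(zip(precision, recall)):
    print(f"class {label}: precision={p:.2f} recall={r:.2f}")
```

These are the same per-class numbers that classification_report prints in its precision and recall columns.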

Below is an example of a classification report

Classification Report:
----------------------
               precision    recall  f1-score   support

      Banana       1.00      1.00      1.00        15
       Bread       1.00      1.00      1.00        15
        Eggs       1.00      1.00      1.00        15
        Milk       1.00      1.00      1.00        15
       Mixed       1.00      1.00      1.00        12
      Potato       1.00      1.00      1.00        15
     Spinach       1.00      1.00      1.00        15
      Tomato       1.00      1.00      1.00        15

    accuracy                           1.00       117
   macro avg       1.00      1.00      1.00       117
weighted avg       1.00      1.00      1.00       117
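
If you want the per-class numbers as values rather than printed text, classification_report can return a dictionary when called with output_dict=True. A small sketch, where the class names and labels are toy values I made up for illustration:

```python
from sklearn.metrics import classification_report

classes = ["cats", "dogs"]   # hypothetical class names, in label order
y_true = [0, 0, 1, 1, 1]     # toy ground-truth labels
y_predict = [0, 1, 1, 1, 0]  # toy predicted labels

# output_dict=True returns nested dicts keyed by class name instead of a string
report = classification_report(
    y_true, y_predict, target_names=classes, output_dict=True, zero_division=0
)

for name in classes:
    print(name, "precision:", report[name]["precision"], "recall:", report[name]["recall"])
```

Each class entry also carries "f1-score" and "support", matching the columns of the printed report.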

I am facing the same problem. I can't find where y_pred or y_test would come from: neither is defined during SSD MobileNetV2 training. If anyone has an answer, please share.
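
One way to look at this: an object detector such as SSD MobileNet outputs scored boxes per image rather than one label per input, so there is no ready-made y_pred. Per-class average precision is usually computed by matching detections to ground-truth boxes by IoU. The following is only a rough sketch of that idea, assuming Pascal VOC style matching at an IoU threshold of 0.5, boxes given as [ymin, xmin, ymax, xmax], and detections and ground truth already grouped by class. The function names and the toy data are hypothetical, not part of the TensorFlow Object Detection API.

```python
import numpy as np

def iou(box, boxes):
    """IoU between one box and an (N, 4) array of boxes, as [ymin, xmin, ymax, xmax]."""
    ymin = np.maximum(box[0], boxes[:, 0])
    xmin = np.maximum(box[1], boxes[:, 1])
    ymax = np.minimum(box[2], boxes[:, 2])
    xmax = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(ymax - ymin, 0, None) * np.clip(xmax - xmin, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)

def average_precision_for_class(detections, ground_truth, iou_threshold=0.5):
    """detections: list of (image_id, score, box) for ONE class.
    ground_truth: dict image_id -> list of boxes for the same class."""
    n_gt = sum(len(b) for b in ground_truth.values())
    if n_gt == 0:
        return float("nan")  # AP is undefined without ground truth for this class
    matched = {img: np.zeros(len(b), dtype=bool) for img, b in ground_truth.items()}
    detections = sorted(detections, key=lambda d: d[1], reverse=True)  # highest score first
    tp = np.zeros(len(detections))
    for i, (img, _, box) in enumerate(detections):
        gt_boxes = np.asarray(ground_truth.get(img, []), dtype=float).reshape(-1, 4)
        if len(gt_boxes) == 0:
            continue  # no ground truth in this image: false positive
        overlaps = iou(np.asarray(box, dtype=float), gt_boxes)
        best = int(np.argmax(overlaps))
        # true positive only if the best-overlapping box is not already claimed
        if overlaps[best] >= iou_threshold and not matched[img][best]:
            matched[img][best] = True
            tp[i] = 1
    cum_tp = np.cumsum(tp)
    cum_fp = np.cumsum(1 - tp)
    recall = cum_tp / n_gt
    precision = cum_tp / np.maximum(cum_tp + cum_fp, 1e-12)
    # make precision non-increasing, then integrate it over recall
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([0.0], precision, [0.0]))
    mpre = np.maximum.accumulate(mpre[::-1])[::-1]
    idx = np.where(mrec[1:] != mrec[:-1])[0]
    return float(np.sum((mrec[idx + 1] - mrec[idx]) * mpre[idx + 1]))

# Toy data keyed by class name (made up for illustration)
gt_by_class = {"cat": {"a": [[0, 0, 10, 10]], "b": [[0, 0, 10, 10]]}}
det_by_class = {"cat": [("a", 0.9, [0, 0, 10, 10]),
                        ("b", 0.8, [20, 20, 30, 30]),
                        ("b", 0.7, [0, 0, 10, 9])]}

ap_per_class = {c: average_precision_for_class(det_by_class.get(c, []), gt)
                for c, gt in gt_by_class.items()}
print(ap_per_class)
```

To use something like this with a trained detector, you would run the model on each test image, collect its boxes, classes and scores, and group them by class together with the annotations from your test set. Averaging the per-class values gives a mean average precision, though it may differ from the number your evaluation tool reports if that tool uses a different matching rule or IoU range.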
