
Multiclass classification report in Keras

So, I have a 4-class classification problem at hand. I have built an ANN as follows:

import tensorflow as tf

# 4-class classifier: four hidden ReLU layers and a softmax output layer
ann = tf.keras.models.Sequential()
ann.add(tf.keras.layers.Dense(units=17, activation='relu'))
ann.add(tf.keras.layers.Dense(units=17, activation='relu'))
ann.add(tf.keras.layers.Dense(units=17, activation='relu'))
ann.add(tf.keras.layers.Dense(units=17, activation='relu'))
ann.add(tf.keras.layers.Dense(units=4, activation='softmax'))  # output: one probability per class, shape (batch, 4)

# sparse_categorical_crossentropy expects ytrain/ytest as integer labels 0..3, not one-hot vectors
ann.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
            metrics=['sparse_categorical_accuracy'])
ann.fit(scaled_xtrain, ytrain, batch_size=8, epochs=20, validation_split=0.2)

test_loss, test_acc = ann.evaluate(scaled_xtest, ytest)
train_loss, train_acc = ann.evaluate(scaled_xtrain, ytrain)
print('Test Accuracy: ', test_acc, '\nTest Loss: ', test_loss)
print('Train Accuracy: ', train_acc, '\nTrain Loss: ', train_loss)

I would like to see its classification report in the following format:

              precision    recall  f1-score   support

           0       0.81      0.76      0.78        88
           1       0.51      0.57      0.54        53
           2       0.62      0.59      0.60        71
           3       0.69      0.72      0.71        57

    accuracy                           0.67       269
   macro avg       0.66      0.66      0.66       269
weighted avg       0.67      0.67      0.67       269
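The three summary rows at the bottom of that report are aggregates of the per-class rows. A short derivation, assuming scikit-learn's documented definitions and using notation introduced here only for illustration: K is the number of classes (4), m_k is a per-class metric (precision, recall or F1-score), n_k is the class's support, TP_k is its number of correct predictions, and N is the total support (269). Macro avg is the unweighted mean over classes; weighted avg weights each class by its support:

```latex
\begin{aligned}
\text{macro avg} &= \frac{1}{K}\sum_{k=1}^{K} m_k,
\qquad
\text{weighted avg} = \sum_{k=1}^{K} \frac{n_k}{N}\, m_k,
\qquad N = \sum_{k=1}^{K} n_k, \\
\text{weighted recall} &= \sum_{k=1}^{K} \frac{n_k}{N}\cdot\frac{TP_k}{n_k}
= \frac{1}{N}\sum_{k=1}^{K} TP_k = \text{accuracy}.
\end{aligned}
```

Because per-class recall is TP_k / n_k, the support weights cancel and the weighted recall equals the accuracy, which is consistent with both reading 0.67 in the report above.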

However, when I write the following code:

from sklearn.metrics import confusion_matrix
ypred= ann.predict_classes(xtest)
ypred= (ypred >0.5)
matrix = confusion_matrix(ytest,ypred)

I get the following error:

TypeError                                 Traceback (most recent call last)
<ipython-input-41-42c5ccb6a924> in <module>()
      2 ypred= ann.predict(xtest)
      3 ypred= (ypred >0.5)
----> 4 matrix = confusion_matrix(ytest.argmax(axis=0), ypred.argmax(axis=0))

4 frames
/usr/local/lib/python3.7/dist-packages/sklearn/utils/validation.py in _num_samples(x)
    267         if len(x.shape) == 0:
    268             raise TypeError(
--> 269                 "Singleton array %r cannot be considered a valid collection." % x
    270             )
    271         # Check that shape is returning an integer or default to len

TypeError: Singleton array 2 cannot be considered a valid collection.

Please help!
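The traceback can be reproduced without TensorFlow. The sketch below is a minimal illustration, not the data from the question: ytest and probs are small hypothetical stand-ins for the integer labels and for the output of ann.predict(xtest), assuming integer labels 0 to 3 as sparse_categorical_crossentropy expects. argmax(axis=0) on a 1-D label array collapses it to a single 0-d index, which is the "singleton array" scikit-learn rejects; on the (n_samples, 4) prediction array it runs down the columns instead of across the classes.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical stand-ins: integer labels and softmax outputs for 5 samples, 4 classes.
ytest = np.array([0, 2, 1, 3, 2])
probs = np.array([
    [0.70, 0.10, 0.10, 0.10],
    [0.10, 0.20, 0.60, 0.10],
    [0.20, 0.50, 0.20, 0.10],
    [0.10, 0.10, 0.20, 0.60],
    [0.30, 0.40, 0.20, 0.10],
])

# argmax over axis 0 of a 1-D array returns one 0-d index, not a label per sample.
print(np.ndim(ytest.argmax(axis=0)))   # 0

# argmax over axis 0 of the 2-D predictions runs down the columns: shape (4,), not (5,).
print(probs.argmax(axis=0).shape)      # (4,)

# Passing the collapsed labels to scikit-learn raises a TypeError
# (the exact message depends on the scikit-learn version).
try:
    confusion_matrix(ytest.argmax(axis=0), probs.argmax(axis=0))
except TypeError as err:
    print("TypeError:", err)

# Correct: one predicted class per row, integer labels used as-is.
ypred = probs.argmax(axis=1)
print(ypred)                           # [0 2 1 3 1]
print(confusion_matrix(ytest, ypred))
```

In short, predictions need argmax(axis=1), and integer labels need no argmax at all.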

If you want a classification report, try sklearn.metrics.classification_report:

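import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

# model => ann, test_images => scaled_xtest, test_labels => ytest
y_pred = model.predict(test_images)           # softmax probabilities, shape (n_samples, 4)
y_pred_class = np.argmax(y_pred, axis=1)      # class index per sample (predict_classes was removed in newer TensorFlow)
labels = np.asarray(test_labels)
# one-hot labels -> class indices; integer labels (sparse_categorical_crossentropy) are used as-is
y_test_class = labels.argmax(axis=1) if labels.ndim > 1 else labels

print(classification_report(y_test_class, y_pred_class))
print(confusion_matrix(y_test_class, y_pred_class))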
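To check this approach without a trained model, here is a self-contained sketch. It assumes hypothetical softmax outputs (computed from random logits nudged towards the true class) in place of ann.predict(scaled_xtest), and random integer labels in place of ytest, so the numbers it prints are synthetic, not the results from the question. output_dict=True is a standard classification_report parameter that returns the same figures as a nested dict.

```python
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

rng = np.random.default_rng(0)
n_samples, n_classes = 269, 4

# Hypothetical stand-ins: integer labels (like ytest) and softmax outputs
# (like ann.predict(scaled_xtest)) whose logits are nudged towards the true class.
y_true = rng.integers(0, n_classes, size=n_samples)
logits = rng.normal(size=(n_samples, n_classes))
logits[np.arange(n_samples), y_true] += 1.0
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

# One predicted class per sample: argmax across the class axis (axis=1).
y_pred = probs.argmax(axis=1)

print(classification_report(y_true, y_pred))
print(confusion_matrix(y_true, y_pred))

# Same figures as a nested dict, e.g. report["macro avg"]["f1-score"].
report = classification_report(y_true, y_pred, output_dict=True)
```

Replacing y_true and probs with ytest and ann.predict(scaled_xtest) should give a report in the format shown in the question.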



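Disclaimer: the technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.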

 
Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM