簡體   English   中英

類數,4,與 target_names 的大小不匹配,6。嘗試指定標簽參數

[英]Number of classes, 4, does not match size of target_names, 6. Try specifying the labels parameter

當我嘗試制作我的 CNN 模型的混淆矩陣時,我遇到了一些問題。當我運行代碼時,它返回一些錯誤,如:

print(classification_report(np.argmax(y_test,axis=1), y_pred,target_names=target_names))

Traceback (most recent call last):

  File "<ipython-input-102-82d46efe536a>", line 1, in <module>
    print(classification_report(np.argmax(y_test,axis=1), y_pred,target_names=target_names))

  File "G:\anaconda_installation_file\lib\site-packages\sklearn\metrics\classification.py", line 1543, in classification_report
    "parameter".format(len(labels), len(target_names))

ValueError: Number of classes, 4, does not match size of target_names, 6. Try specifying the labels parameter

我已經搜索過要解決這個問題,但仍然沒有得到完美的解決方案。 我是這個領域的新手,有人能幫我嗎? 謝謝。

from sklearn.metrics import classification_report,confusion_matrix
import itertools

Y_pred = model.predict(X_test)
print(Y_pred)
y_pred = np.argmax(Y_pred, axis=1)
print(y_pred)

target_names = ['class 0(cardboard)', 'class 1(glass)', 'class 2(metal)','class 3(paper)', 'class 4(plastic)','class 5(trash)']

print(classification_report(np.argmax(y_test,axis=1), y_pred,target_names=target_names))

問題是您有 6 個標簽名稱: 'class 0(cardboard)', 'class 1(glass)', 'class 2(metal)','class 3(paper)', 'class 4(plastic)','class 5(trash)'

但是你的混淆print(y_pred)只有 4 個類,當你打印時: print(y_pred) :你會得到帶有0,1,2,3數字,或者當你print(y_test)你會得到來自0,1,2,3數字0,1,2,3 ,它應該有助於刪除:

print(classification_report(np.argmax(y_test,axis=1), y_pred,target_names=target_names))

從您的代碼中,不知何故您沒有 6 個預測/測試類。

這里還有一個如何繪制混淆矩陣的示例: 如何繪制混淆矩陣?

你應該更好地提出你的問題! 我在做一些假設!
問題是:

target_names = ['class 0(紙板)','class 1(玻璃)','class 2(金屬)','class 3(紙)','class 4(塑料)','class 5(垃圾)' ]

有 6 個類,您的模型只能預測 4 個類,這些類會引發錯誤,因為混淆矩陣提供了 4 個類(它應該是 6x6 而不是 6x4)。
要糾正這個問題,也只需提供類標簽。 對於 ecample,如果有 3 個標簽(在預測變量中),即 1,2,3

打印(分類報告(y_true,y_pred,標簽=[1、2、3]))

請參閱此處的文檔https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html

PS:

  1. 您的模型表現不佳。

  2. 您的數據集可能存在類不平衡問題。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM