簡體   English   中英

無法打印模型的混淆矩陣

[英]Cannot print confusion matrix for model

我實現了一個MLP ,並且運行良好。 但是,我在嘗試打印混淆矩陣時遇到問題。

我的模型定義為...

logits = layers(X, weights, biases)

哪里...

def layers(x, weights, biases):
    layer_1 = tf.add(tf.matmul(x, weights['h1']), biases['b1'])
    layer_2 = tf.add(tf.matmul(layer_1, weights['h2']), biases['b2'])
    out_layer = tf.matmul(layer_2, weights['out']) + biases['out']

    return out_layer

我在mnist數據集上訓練模型。 經過培訓,我能夠成功打印出模型的准確性...

pred = tf.nn.softmax(logits)

correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print("Accuracy: ", accuracy.eval({X:mnist.test.images, y:mnist.test.labels}))

准確性給我90%。 現在,我要打印出結果的混淆矩陣。 我嘗試了以下...

confusion = tf.confusion_matrix(
         labels=mnist.test.labels, predictions=correct_prediction)

但這給了我錯誤...

ValueError:無法擠壓dim [1],預期尺寸為1,輸入形狀為[10000,10]的“ confusion_matrix / remove_squeezable_dimensions / Squeeze”(操作數:“ Squeeze”)得到10。

打印混淆矩陣的正確方法是什么? 我已經掙扎了一段時間。

看起來tf.confusion_matrix的參數之一的第二個昏暗tf.confusion_matrix 10。 問題是mnist.test.labelscorrect_prediction是否被一鍵編碼? 那就可以解釋了。 您需要在那里的標簽作為一個暗張量。 您可以打印這兩個張量的形狀嗎?

而且看起來correct_prediction是一個布爾張量,用於標記您的預測是否准確。 對於混淆矩陣,您需要預測的標簽,而不是tf.argmax( pred, 1 ) 同樣,如果您的標簽是一鍵編碼的,則需要對它們進行解碼以解決混淆矩陣。 因此,請嘗試將此行confusion

confusion = tf.confusion_matrix(
     labels = tf.argmax( mnist.test.labels, 1 ),
     predictions = tf.argmax( pred, 1 ) )

為了打印混淆矩陣本身,必須將eval與最終結果一起使用:

print(confusion.eval({x:mnist.test.images, y:mnist.test.labels}))

這對我有用:

confusion = tf.confusion_matrix(
       labels = tf.argmax( mnist.test.labels, 1 ),
       predictions = tf.argmax( y, 1 ) )
   print(confusion.eval({x:mnist.test.images, y_:mnist.test.labels})) 

[[ 960    0    2    2    1    5    7    2    1    0]
 [   0 1113    3    2    0    1    4    2   10    0]
 [   6    7  941   15   12    2   10    8   27    4]
 [   2    1   27  926    1   12    1    8   24    8]
 [   1    2    6    1  928    0    9    2    9   24]
 [   9    2    8   51   12  729   15    9   50    7]
 [  13    3   10    2    9    9  905    2    5    0]
 [   1    9   28    8   11    1    0  938    3   29]
 [   6   10    7   19    9   13    8    5  891    6]
 [   9    7    2    9   43    5    0   14   12  908]]

對於NLTK混淆矩陣,您需要一個列表

classifier = NaiveBayesClassifier.train(trainfeats)
refsets = collections.defaultdict(set)
testsets = collections.defaultdict(set)

lsum = []
tsum = []

for i, (feats, label) in enumerate(testfeats):
  refsets[label].add(i)
  observed = classifier.classify(feats)
  testsets[observed].add(i)
  lsum.append(label)
  tsum.append(observed

print (nltk.ConfusionMatrix(lsum,tsum))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM