How to create a confusion matrix for classification in TensorFlow
I have a CNN model with 4 output nodes, and I am trying to compute the confusion matrix so that I can see the accuracy of each individual class. I am able to compute the overall accuracy. In the link here, Igor Valantic gave a function which can compute the confusion matrix variables. It gives me an error at
correct_prediction = tf.nn.in_top_k(logits, labels, 1, name="correct_answers")
and the error is TypeError: DataType float32 for attr 'T' not in list of allowed values: int32, int64
I have tried casting logits to int32 inside the function mentioned, def evaluation(logits, labels), but that gives another error when computing correct_prediction = ...:
TypeError: Input 'predictions' of 'InTopK' Op has type int32 that does not match expected type of float32
How can I calculate this confusion matrix?
sess = tf.Session()
model = dimensions() # CNN input weights are calculated
data_train, data_test, label_train, label_test = load_data(files_test2,folder)
data_train, data_test, = reshapedata(data_train, data_test, model)
# input output placeholders
x = tf.placeholder(tf.float32, [model.BATCH_SIZE, model.input_width,model.input_height,model.input_depth]) # last column = 1
y_ = tf.placeholder(tf.float32, [model.BATCH_SIZE, model.No_Classes])
p_keep_conv = tf.placeholder("float")
#
y = mycnn(x,model, p_keep_conv)
# loss
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
# train step
train_step = tf.train.AdamOptimizer(1e-3).minimize(cost)
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
true_positives, false_positives, true_negatives, false_negatives = evaluation(y,y_)
lossfun = np.zeros(STEPS)
sess.run(tf.global_variables_initializer())
for i in range(STEPS):
    image_batch, label_batch = batchdata(data_train, label_train, model.BATCH_SIZE)
    epoch_loss = 0
    for j in range(model.BATCH_SIZE):
        sess.run(train_step, feed_dict={x: image_batch, y_: label_batch, p_keep_conv: 1.0})
        c = sess.run(cost, feed_dict={x: image_batch, y_: label_batch, p_keep_conv: 1.0})
        epoch_loss += c
    lossfun[i] = epoch_loss
    print('Epoch', i, 'completed out of', STEPS, 'loss:', epoch_loss)
TP, FP, TN, FN = sess.run([true_positives, false_positives, true_negatives, false_negatives], feed_dict={x: image_batch, y_: label_batch, p_keep_conv: 1.0})
This is my code snippet.
You can simply use TensorFlow's confusion matrix. I assume y holds your predictions; num_classes is optional, so you may or may not have it.
y_ = placeholder_for_labels # e.g. [1, 2, 4]
y = mycnn(...) # e.g. [2, 2, 4]
confusion = tf.confusion_matrix(labels=y_, predictions=y, num_classes=num_classes)
If you print(confusion), you get:
[[0 0 0 0 0]
[0 0 1 0 0]
[0 0 1 0 0]
[0 0 0 0 0]
[0 0 0 0 1]]
If print(confusion) does not print the confusion matrix, use print(confusion.eval(session=sess)) instead. Here sess is the name of your TensorFlow session.
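One thing worth noting: tf.confusion_matrix expects 1-D vectors of integer class ids, while in the question's code y is a [batch, 4] tensor of scores and y_ is a one-hot float tensor. The in_top_k error in the question appears to report the same kind of mismatch, since its targets must be int32 or int64 class ids. Below is a minimal plain-Python sketch (no TensorFlow required) of the conversion and of what the confusion matrix counts. The helper names argmax and confusion_matrix and the sample data are made up for illustration.

```python
# Plain-Python sketch; the scores and one-hot labels are made-up illustration data.

def argmax(row):
    """Return the index of the largest value in a list."""
    return max(range(len(row)), key=lambda i: row[i])

def confusion_matrix(labels, predictions, num_classes):
    """Count (true class, predicted class) pairs.
    Rows are true classes, columns are predicted classes."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for true_cls, pred_cls in zip(labels, predictions):
        matrix[true_cls][pred_cls] += 1
    return matrix

# Hypothetical batch of 3 samples for a 4-class model.
scores = [[0.1, 2.0, 0.3, 0.0],
          [0.2, 0.1, 1.5, 0.4],
          [0.9, 0.1, 0.0, 0.3]]
one_hot = [[0, 1, 0, 0],
           [0, 0, 0, 1],
           [1, 0, 0, 0]]

pred_ids = [argmax(r) for r in scores]    # plays the role of tf.argmax(y, 1)
label_ids = [argmax(r) for r in one_hot]  # plays the role of tf.argmax(y_, 1)

cm = confusion_matrix(label_ids, pred_ids, num_classes=4)
for row in cm:
    print(row)
```

In the question's graph, the corresponding call would presumably be tf.confusion_matrix(labels=tf.argmax(y_, 1), predictions=tf.argmax(y, 1), num_classes=model.No_Classes), evaluated with the same feed_dict as the accuracy.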
import tensorflow as tf
y = [1, 2, 4]
y_ = [2, 2, 4]
con = tf.confusion_matrix(labels=y_, predictions=y )
sess = tf.Session()
with sess.as_default():
    print(sess.run(con))
The output is:
[[0 0 0 0 0]
[0 0 0 0 0]
[0 1 1 0 0]
[0 0 0 0 0]
[0 0 0 0 1]]
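Since the goal in the question is the accuracy of each individual class, here is a small sketch of how those numbers could be read off a confusion matrix once it has been fetched from the session. It is plain Python, the function names per_class_accuracy and class_counts are hypothetical, and it assumes rows are true labels and columns are predictions, as in the output above.

```python
def per_class_accuracy(matrix):
    """Fraction of each true class that was predicted correctly
    (diagonal entry divided by the row sum). None if the class has no samples."""
    result = []
    for i, row in enumerate(matrix):
        total = sum(row)
        result.append(row[i] / total if total else None)
    return result

def class_counts(matrix, cls):
    """Return (TP, FP, TN, FN) for one class, treating it as one-vs-rest."""
    total = sum(sum(row) for row in matrix)
    tp = matrix[cls][cls]
    fn = sum(matrix[cls]) - tp                  # true cls, predicted as something else
    fp = sum(row[cls] for row in matrix) - tp   # predicted cls, truly something else
    tn = total - tp - fp - fn
    return tp, fp, tn, fn

# The matrix printed above (labels [2, 2, 4], predictions [1, 2, 4]).
cm = [[0, 0, 0, 0, 0],
      [0, 0, 0, 0, 0],
      [0, 1, 1, 0, 0],
      [0, 0, 0, 0, 0],
      [0, 0, 0, 0, 1]]

acc = per_class_accuracy(cm)
counts_class_2 = class_counts(cm, 2)
print(acc)
print(counts_class_2)
```

Classes that never occur in the labels have an empty row, so the sketch returns None for them rather than dividing by zero.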