
TensorFlow Confusion Matrix in TensorBoard

I want to have a visualization of a confusion matrix in TensorBoard. To do this, I am modifying the evaluation example of TensorFlow-Slim: https://github.com/tensorflow/models/blob/master/slim/eval_image_classifier.py

In this example code, accuracy is already provided, but it is not possible to add a "confusion matrix" metric directly, because it is not streaming.

What is the difference between streaming metrics and non-streaming ones?
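(In short: a streaming metric returns a (value, update_op) pair backed by local variables, so its value accumulates across batches, while an op like tf.confusion_matrix only computes a tensor for the single batch it is run on. A minimal sketch of the distinction; the placeholders and num_classes=5 here are illustrative assumptions:)

import tensorflow as tf

labels = tf.placeholder(tf.int64, [None])
predictions = tf.placeholder(tf.int64, [None])

# Streaming: (value, update_op) pair, state lives in local variables
acc, acc_update = tf.metrics.accuracy(labels, predictions)

# Non-streaming: a plain tensor recomputed from the current batch only
batch_cm = tf.confusion_matrix(labels, predictions, num_classes=5)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # streaming metrics need this
    for batch_labels, batch_preds in [([0, 1], [0, 0]), ([2, 2], [2, 1])]:
        sess.run(acc_update, {labels: batch_labels, predictions: batch_preds})
    print(sess.run(acc))  # accuracy accumulated over all batches seen so far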

Therefore, I tried to add it like this:

c_matrix = slim.metrics.confusion_matrix(predictions, labels)

# These operations are needed for the image summary
c_matrix = tf.cast(c_matrix, tf.uint8)
c_matrix = tf.expand_dims(c_matrix, 2)
c_matrix = tf.expand_dims(c_matrix, 0)

op = tf.image_summary("confusion matrix", c_matrix, collections=[])
tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

This creates an image in TensorBoard, but there is probably a formatting problem. The matrix should be normalized between 0 and 1 so that it produces a meaningful image.

How can I produce a meaningful confusion matrix? How can I deal with a multi-batch evaluation process?

Here is something I have put together that works reasonably well. I still need to adjust a few things, like the tick placements.

[Image: confusion matrix rendered as an image in TensorBoard]

Here is the function that will pretty much do everything for you.

from textwrap import wrap
import re
import itertools
import tfplot
import matplotlib.figure
import numpy as np
from sklearn.metrics import confusion_matrix



def plot_confusion_matrix(correct_labels, predict_labels, labels,
                          title='Confusion matrix',
                          tensor_name='MyFigure/image', normalize=False):
    '''
    Parameters:
        correct_labels                  : These are your true classification categories.
        predict_labels                  : These are your predicted classification categories.
        labels                          : This is a list of labels which will be used to display the axis labels.
        title='Confusion matrix'        : Title for your matrix.
        tensor_name = 'MyFigure/image'  : Name for the output summary tensor.

    Returns:
        summary: TensorFlow summary

    Other items to note:
        - Depending on the number of categories and the data, you may have to
          modify the figsize, font sizes, etc.
        - Currently, some of the ticks don't line up due to rotations.
    '''
    cm = confusion_matrix(correct_labels, predict_labels, labels=labels)
    if normalize:
        # Scale each row (true class) to an integer 0-10 range
        cm = cm.astype('float') * 10 / cm.sum(axis=1)[:, np.newaxis]
        cm = np.nan_to_num(cm, copy=True)
        cm = cm.astype('int')

    np.set_printoptions(precision=2)

    fig = matplotlib.figure.Figure(figsize=(7, 7), dpi=320, facecolor='w', edgecolor='k')
    ax = fig.add_subplot(1, 1, 1)
    im = ax.imshow(cm, cmap='Oranges')

    # Split CamelCase label names into words and wrap long ones
    classes = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', x) for x in labels]
    classes = ['\n'.join(wrap(l, 40)) for l in classes]

    tick_marks = np.arange(len(classes))

    ax.set_xlabel('Predicted', fontsize=7)
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes, fontsize=4, rotation=-90, ha='center')
    ax.xaxis.set_label_position('bottom')
    ax.xaxis.tick_bottom()

    ax.set_ylabel('True Label', fontsize=7)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes, fontsize=4, va='center')
    ax.yaxis.set_label_position('left')
    ax.yaxis.tick_left()

    # Write the count into each cell; use '.' for empty cells
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, format(cm[i, j], 'd') if cm[i, j] != 0 else '.',
                horizontalalignment='center', fontsize=6,
                verticalalignment='center', color='black')
    fig.set_tight_layout(True)
    summary = tfplot.figure.to_summary(fig, tag=tensor_name)
    return summary

And here is the rest of the code that you will need to call this function.

import os
import tensorflow as tf

''' confusion matrix summaries '''
img_d_summary_dir = os.path.join(checkpoint_dir, "summaries", "img")
img_d_summary_writer = tf.summary.FileWriter(img_d_summary_dir, sess.graph)
img_d_summary = plot_confusion_matrix(correct_labels, predict_labels, labels, tensor_name='dev/cm')
img_d_summary_writer.add_summary(img_d_summary, current_step)
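Note that correct_labels and predict_labels above are plain arrays covering the whole evaluation set, which is one way to handle the multi-batch question: accumulate them over the evaluation loop before plotting. A minimal sketch, where num_eval_batches, labels_op, and predictions_op are hypothetical names for your own graph tensors:

correct_labels, predict_labels = [], []
for _ in range(num_eval_batches):
    batch_labels, batch_preds = sess.run([labels_op, predictions_op])
    correct_labels.extend(batch_labels)   # collect ground truth per batch
    predict_labels.extend(batch_preds)    # collect predictions per batch
# After the loop, a single summary covers every batch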

Confuse away!!!

Here's how I produced and displayed a "streaming" confusion matrix for test code (the returned test_op is evaluated for each batch of test data).

def _get_streaming_metrics(prediction, label, num_classes):

    with tf.name_scope("test"):
        # The streaming accuracy (lookup and update tensors)
        accuracy, accuracy_update = tf.metrics.accuracy(label, prediction,
                                                        name='accuracy')
        # Compute a per-batch confusion matrix
        batch_confusion = tf.confusion_matrix(label, prediction,
                                              num_classes=num_classes,
                                              name='batch_confusion')
        # Create an accumulator variable to hold the counts
        confusion = tf.Variable(tf.zeros([num_classes, num_classes],
                                         dtype=tf.int32),
                                name='confusion')
        # Create the update op for doing a "+=" accumulation on the batch
        confusion_update = confusion.assign(confusion + batch_confusion)
        # Cast counts to float so tf.summary.image renormalizes to [0,255]
        confusion_image = tf.reshape(tf.cast(confusion, tf.float32),
                                     [1, num_classes, num_classes, 1])
        # Combine streaming accuracy and confusion matrix updates in one op
        test_op = tf.group(accuracy_update, confusion_update)

        tf.summary.image('confusion', confusion_image)
        tf.summary.scalar('accuracy', accuracy)

    return test_op, accuracy, confusion

After you process all the data batches by running test_op, you can simply look up the final confusion matrix (within your session) via confusion.eval(), or sess.run(confusion) if you prefer.
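A minimal sketch of that evaluation loop; x, y, and test_batches are hypothetical placeholders for your own inputs:

test_op, accuracy, confusion = _get_streaming_metrics(prediction, label, num_classes)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())   # the confusion accumulator
    sess.run(tf.local_variables_initializer())    # tf.metrics.accuracy state
    for batch_x, batch_y in test_batches:         # hypothetical batch iterator
        sess.run(test_op, feed_dict={x: batch_x, y: batch_y})
    print(confusion.eval(session=sess))           # final accumulated counts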

Here is something that works with tf.contrib.metrics.MetricSpec (when you use Estimator). It is inspired by Jerod's answer and the metric_ops.py source file. You get a streamed confusion matrix with percentages:

from tensorflow.python.framework import ops, dtypes
from tensorflow.python.ops import array_ops, variables


def _createLocalVariable(name, shape, collections=None,
                         validate_shape=True, dtype=dtypes.float32):
    """Creates a new local variable."""
    # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES
    collections = list(collections or [])
    collections += [ops.GraphKeys.LOCAL_VARIABLES]
    return variables.Variable(
        initial_value=array_ops.zeros(shape, dtype=dtype),
        name=name,
        trainable=False,
        collections=collections,
        validate_shape=validate_shape)

def streamingConfusionMatrix(label, prediction, weights=None, num_classes=None):
    """
    Compute a streaming confusion matrix.
    :param label: True labels
    :param prediction: Predicted labels
    :param weights: (Optional) weights (unused)
    :param num_classes: Number of labels for the confusion matrix
    :return: (percentConfusionMatrix, updateOp)
    """
    # Compute a per-batch confusion matrix
    batch_confusion = tf.confusion_matrix(label, prediction,
                                          num_classes=num_classes,
                                          name='batch_confusion')

    count = _createLocalVariable(None, (), dtype=tf.int32)
    confusion = _createLocalVariable('streamConfusion',
                                     [num_classes, num_classes],
                                     dtype=tf.int32)

    # Create the update ops for doing a "+=" accumulation on the batch
    countUpdate = count.assign(count + tf.reduce_sum(batch_confusion))
    confusionUpdate = confusion.assign(confusion + batch_confusion)

    updateOp = tf.group(confusionUpdate, countUpdate)

    percentConfusion = 100 * tf.truediv(confusion, count)

    return percentConfusion, updateOp

You can then use it as an evaluation metric in the following way:

from tensorflow.contrib import learn, metrics
# [...]

evalMetrics = {
    'accuracy': learn.MetricSpec(metric_fn=metrics.streaming_accuracy),
    'confusionMatrix': learn.MetricSpec(
        metric_fn=lambda label, prediction, weights=None:
            streamingConfusionMatrix(label, prediction, weights,
                                     num_classes=nLabels))}

I suggest you use numpy.set_printoptions(precision=2, suppress=True) to print it out.
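For example, assuming you kept a handle to the percentConfusion tensor returned by streamingConfusionMatrix and a session named sess:

import numpy as np

np.set_printoptions(precision=2, suppress=True)  # fixed-point, no scientific notation
print(sess.run(percentConfusion))  # percentages accumulated over all batches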

Re: your image not being meaningful: according to the docs for tf.summary.image, uint8 values are unchanged (they won't be normalized) and are interpreted in the range [0, 255]. Have you tried re-normalizing your image to [0, 255] instead of [0, 1]?
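Concretely, that would mean stretching the counts to the full uint8 range before the cast. A sketch starting again from the raw slim.metrics.confusion_matrix output in the question; the tf.maximum guard is an added assumption to avoid dividing by zero:

c_matrix = tf.cast(c_matrix, tf.float32)
# Stretch the counts so the largest cell maps to 255
c_matrix = 255.0 * c_matrix / tf.maximum(tf.reduce_max(c_matrix), 1.0)
c_matrix = tf.cast(c_matrix, tf.uint8)
c_matrix = tf.expand_dims(tf.expand_dims(c_matrix, 2), 0)
summary_op = tf.summary.image("confusion_matrix", c_matrix, collections=[])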
