
Adding text labels to confusion matrix in Tensorflow for Tensorboard

I am customizing code from Tensorflow's example retrain.py to train on my own images, adding an extra dense layer, dropout, momentum gradient descent, and so on.

I want to add a confusion matrix to the Tensorboard summaries, so I followed the first answer (Jerod's) to this post (I also tried the second answer, but ran into some debugging issues) and added a few lines to the add_evaluation_step function. It now looks like:

def add_evaluation_step(result_tensor, ground_truth_tensor):

  with tf.name_scope('accuracy'):
    with tf.name_scope('correct_prediction'):
      prediction = tf.argmax(result_tensor, 1)
      correct_prediction = tf.equal(
          prediction, tf.argmax(ground_truth_tensor, 1))
    with tf.name_scope('accuracy'):
      evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  tf.summary.scalar('accuracy', evaluation_step)
  print('prediction shape :: {}'.format(ground_truth_tensor))

  # Add confusion matrix
  batch_confusion = tf.confusion_matrix(tf.argmax(ground_truth_tensor, 1), prediction,
                                        num_classes=7,
                                        name='batch_confusion')
  # Create an accumulator variable to hold the counts
  confusion = tf.Variable(tf.zeros([7, 7], dtype=tf.int32),
                          name='confusion')
  # Create the update op for doing a "+=" accumulation on the batch
  confusion_update = confusion.assign(confusion + batch_confusion)
  # Cast counts to float so tf.summary.image renormalizes to [0,255]
  confusion_image = tf.reshape(tf.cast(confusion_update, tf.float32),
                               [1, 7, 7, 1])

  tf.summary.image('confusion', confusion_image)

  return evaluation_step, prediction

This gives me: confusion matrix

My question is how to add labels to the rows (actual classes) and columns (predicted classes), to get something like:

Desired confusion matrix

Jerod's answer contains almost everything you need, along with yauheni_selivonchyk's other answer on how to add custom images to Tensorboard.

Then it is just a matter of putting everything together, i.e.:

  1. Implementing a method to pass plotted images to the summaries (as RGB arrays)
  2. Implementing a method to convert matrix data into a plotted confusion image
  3. Defining the running evaluation operations to fetch the confusion matrix data (along with other metrics), and preparing a placeholder and summary to receive the plotted image
  4. Using everything together

1. Implementing a method to pass plotted images to the summaries

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
import tensorflow as tf

# Inspired by yauheni_selivonchyk on SO (https://stackoverflow.com/a/42815564/624547)

def get_figure(figsize=(10, 10), dpi=300):
    """
    Return a pyplot figure
    :param figsize:
    :param dpi:
    :return:
    """
    fig = plt.figure(num=0, figsize=figsize, dpi=dpi)
    fig.clf()
    return fig


def fig_to_rgb_array(fig, expand=True):
    """
    Convert figure into a RGB array
    :param fig:         PyPlot Figure
    :param expand:      Flag to expand
    :return:            RGB array
    """
    fig.canvas.draw()
    buf = fig.canvas.tostring_rgb()
    ncols, nrows = fig.canvas.get_width_height()
    shape = (nrows, ncols, 3) if not expand else (1, nrows, ncols, 3)
    # np.frombuffer avoids the deprecated np.fromstring for binary data:
    return np.frombuffer(buf, dtype=np.uint8).reshape(shape)


def figure_to_summary(fig, summary, place_holder):
    """
    Convert figure into TF summary
    :param fig:             Figure
    :param summary:         Summary to eval
    :param place_holder:    Summary image placeholder
    :return:                Summary
    """
    image = fig_to_rgb_array(fig)
    return summary.eval(feed_dict={place_holder: image})

2. Converting matrix data into a plotted confusion image

(This is one example, but it depends on what you want.)

def confusion_matrix_to_image_summary(confusion_matrix, summary, place_holder, 
                                      list_classes, figsize=(9, 9)):
    """
    Plot confusion matrix and return as TF summary
    :param matrix:          Confusion matrix (N x N)
    :param filename:        Filename
    :param list_classes:    List of classes (N)
    :param figsize:         Pyplot figsize for the confusion image
    :return:                /
    """
    fig = get_figure(figsize=(9, 9))
    df = pd.DataFrame(confusion_matrix, index=list_classes, columns=list_classes)
    ax = sns.heatmap(df, annot=True, fmt='.0%')
    # Whatever embellishments you want:
    plt.title('Confusion matrix')
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)
    image_sum = figure_to_summary(fig, summary, place_holder)
    return image_sum
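Note that the accumulator defined in step 3 stores raw integer counts, while the heatmap above formats each cell with fmt='.0%'. If you want the cells shown as per-class percentages, a minimal sketch of a row-normalization helper (a hypothetical addition, not part of the original answer) could look like this:

def normalize_confusion_matrix(confusion_matrix):
    """Hypothetical helper: turn raw counts into row-normalized fractions."""
    counts = confusion_matrix.astype(np.float64)
    row_sums = counts.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1.0  # avoid division by zero for classes absent from the eval set
    return counts / row_sums

You would then pass the normalized matrix to confusion_matrix_to_image_summary instead of the raw counts.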

3. Defining the evaluation operations and preparing the placeholder

# Inspired by Jerod's answer on SO (https://stackoverflow.com/a/42857070/624547)    
def add_evaluation_step(result_tensor, ground_truth_tensor, num_classes, confusion_matrix_figsize=(9, 9)):
    """
    Sets up the evaluation operations, computing the running accuracy and confusion image
    :param result_tensor:               Output tensor
    :param ground_truth_tensor:         Target class tensor
    :param num_classes:                 Number of classes
    :param confusion_matrix_figsize:    Pyplot figsize for the confusion image
    :return:                            TF operations, summaries and placeholders (see usage below)
    """
    scope = "evaluation"
    with tf.name_scope(scope):
        predictions = tf.argmax(result_tensor, 1, name="prediction")

        # Streaming accuracy (lookup and update tensors):
        accuracy, accuracy_update = tf.metrics.accuracy(ground_truth_tensor, predictions, name='accuracy')
        # Per-batch confusion matrix:
        batch_confusion = tf.confusion_matrix(ground_truth_tensor, predictions, num_classes=num_classes,
                                              name='batch_confusion')

        # Aggregated confusion matrix:
        confusion_matrix = tf.Variable(tf.zeros([num_classes, num_classes], dtype=tf.int32),
                                       name='confusion')
        confusion_update = confusion_matrix.assign(confusion_matrix + batch_confusion)

        # We suppose each batch contains a complete class, to directly normalize by its size:
        evaluate_streaming_metrics_op = tf.group(accuracy_update, confusion_update)

        # Confusion image from matrix (need to extend dims + cast to float so tf.summary.image renormalizes to [0,255]):
        confusion_image = tf.reshape(tf.cast(confusion_update, tf.float32), [1, num_classes, num_classes, 1])

        # Summaries:
        tf.summary.scalar('accuracy', accuracy, collections=[scope])
        summary_op = tf.summary.merge_all(scope)

        # Preparing placeholder for confusion image (so that we can pass the plotted image to it):
        #      (we basically pre-allocate a plot figure and pass its RGB array to a placeholder)
        confusion_image_placeholder = tf.placeholder(tf.uint8,
                                                     fig_to_rgb_array(get_figure(figsize=confusion_matrix_figsize)).shape)
        confusion_image_summary = tf.summary.image('confusion_image', confusion_image_placeholder)

    # Isolating all the variables stored by the metric operations:
    running_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)
    running_vars += tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope=scope)

    # Initializer op to start/reset running variables
    reset_streaming_metrics_op = tf.variables_initializer(var_list=running_vars)

    return evaluate_streaming_metrics_op, reset_streaming_metrics_op, summary_op, confusion_image_summary, \
           confusion_image_placeholder, confusion_image

4. Putting everything together

A quick example of how to use it, though it will need to be adapted to your training procedure and so on.

classes = ["obj1", "obj2", "obj3"]
num_classes = len(classes)
model = your_network(...)

evaluate_streaming_metrics_op, reset_streaming_metrics_op, summary_op, \
confusion_image_summary, confusion_image_placeholder, confusion_image = \
    add_evaluation_step(model.output, model.target, num_classes)

def evaluate(session, model, eval_data_gen):
    """
    Evaluate the model
    :param session:         TF session
    :param model:           Model to evaluate
    :param eval_data_gen:   Data to evaluate on
    :return:                Evaluation summaries for Tensorboard
    """
    # Resetting streaming vars:
    session.run(reset_streaming_metrics_op)

    # Evaluating running ops over complete eval dataset, e.g.:
    for batch in eval_data_gen:
        feed_dict = {model.inputs: batch}
        session.run(evaluate_streaming_metrics_op, feed_dict=feed_dict)

    # Obtaining the final results:
    summary_str, confusion_results = session.run([summary_op, confusion_image])

    # Converting confusion data into plot into summary:
    confusion_img_str = confusion_matrix_to_image_summary(
        confusion_results[0,:,:,0], confusion_image_summary, confusion_image_placeholder, classes)
    summary_str += confusion_img_str

    return summary_str # to be given to a SummaryWriter
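For completeness, here is a minimal sketch of how the returned summary string could be written out; log_dir and step are illustrative names for your log directory and global step. The evaluate function above concatenates the two serialized summaries, which a single add_summary call will parse as one Summary.

summary_writer = tf.summary.FileWriter(log_dir, session.graph)

summary_str = evaluate(session, model, eval_data_gen)
summary_writer.add_summary(summary_str, global_step=step)
summary_writer.flush()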

Following MLNINJA's answer helped me get not only the labels, but also a nice live streaming visualization. Here is how I did it. First, I wrote this function in retrain.py:

from textwrap import wrap
import itertools
import matplotlib
import numpy as np
import tensorflow as tf
import tfplot
import os
import re

def plot_confusion_matrix(correct_labels, predict_labels,labels,session, title='Confusion matrix', tensor_name = 'MyFigure/image', normalize=False):
  conf = tf.contrib.metrics.confusion_matrix(correct_labels, predict_labels)

  cm=session.run(conf)

  if normalize:
    cm = cm.astype('float')*10 / cm.sum(axis=1)[:, np.newaxis]
    cm = np.nan_to_num(cm, copy=True)
    cm = cm.astype('int')

  np.set_printoptions(precision=2)

  fig = matplotlib.figure.Figure(figsize=(7, 7), dpi=320, facecolor='w', edgecolor='k')
  ax = fig.add_subplot(1, 1, 1)
  im = ax.imshow(cm, cmap='Oranges')

  classes = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', x) for x in labels]
  classes = ['\n'.join(wrap(l, 40)) for l in classes]

  tick_marks = np.arange(len(classes))

  ax.set_xlabel('Predicted', fontsize=7)
  ax.set_xticks(tick_marks)
  c = ax.set_xticklabels(classes, fontsize=10, rotation=-90,  ha='center')
  ax.xaxis.set_label_position('bottom')
  ax.xaxis.tick_bottom()

  ax.set_ylabel('True Label', fontsize=7)
  ax.set_yticks(tick_marks)
  ax.set_yticklabels(classes, fontsize=10, va ='center')
  ax.yaxis.set_label_position('left')
  ax.yaxis.tick_left()

  for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    ax.text(j, i, format(cm[i, j], 'd') if cm[i,j]!=0 else '.', horizontalalignment="center", fontsize=6, verticalalignment='center', color= "black")

  fig.set_tight_layout(True)
  summary = tfplot.figure.to_summary(fig, tag=tensor_name)
  return summary

In my version of retrain.py, a summary writer conf__writer for the confusion matrix is first created at line 1227. Then the function is called (at line 1287) inside the if clause (line 1261) that runs at every evaluation step, and finally the summary is written to the summaries directory at line 1288.

Note that the add_evaluation_step function was also modified to return the tensor for the ground-truth input. At line 1278 this is run to get the array of ground-truth inputs, which is fed to the plot_confusion_matrix function.
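The exact retrain.py changes are not reproduced here, but a rough sketch of the calls described above could look as follows (conf__writer, sess, feed_dict, class_labels, i and is_last_step are placeholders standing in for the variables of that script, not the actual code):

# Rough sketch only; placeholder names, not the actual retrain.py code.
conf__writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/confusion', sess.graph)

if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
    # Run the prediction and (modified) ground-truth tensors on the validation batch:
    preds, truths = sess.run([prediction, ground_truth_tensor], feed_dict=feed_dict)
    # Ground truth comes back one-hot, so reduce it to class indices:
    summary = plot_confusion_matrix(np.argmax(truths, axis=1), preds,
                                    class_labels, sess)
    conf__writer.add_summary(summary, i)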
