簡體   English   中英

TensorFlow估算器混淆矩陣

[英]TensorFlow Estimator Confusion Matrix

我是Tensorflow的新手,我剛剛運行了我的第一個神經網絡分類器,該分類器是從https://www.tensorflow.org/get_started/estimator獲得的代碼。 它工作成功,但只顯示了精度。 如何輸出混淆矩陣? 我只有2個標簽。 1和0。

這是代碼的最后一部分。 與鏈接相同。

  # Train model.
  classifier.train(input_fn=train_input_fn, steps=2000)

  # Define the test inputs
  test_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={"x": np.array(test_set.data)},
      y=np.array(test_set.target),
      num_epochs=1,
      shuffle=True)

  # Evaluate accuracy.
  accuracy_score = classifier.evaluate(input_fn=test_input_fn)["accuracy"]

您可以使用tf.confusion_matrix生成混淆矩陣。 特別是,類似以下的內容應該起作用:

labels = list(test_set.target)
predictions = list(classifier.predict(input_fn=test_input_fn))
confusion_matrix = tf.confusion_matrix(labels, predictions)

可能還有更多。

labels = list(test_set[label_column])
raw_predictions = regressor.predict(input_fn=get_input_fn(test_set)
predictions = [p['class_ids'][0] for p in raw_predictions]
confusion_matrix = tf.confusion_matrix(labels, predictions)

原始預測輸出是標簽和概率的指示。 您需要獲取標簽。

您可能還需要添加以下內容:

with tf.Session():
    print('\nConfusion Matrix:\n', tf.Tensor.eval(confusion_matrix,feed_dict=None, session=None))

為了打印矩陣

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM