[英]Keras - How to use argmax for predictions
I have 3 categories of classes Tree, Stump, Ground
. 我有3类类别
Tree, Stump, Ground
。 I've made a list for these categories: 我列出了以下类别的清单:
CATEGORIES = ["Tree", "Stump", "Ground"]
When i print my prediction, it gives me the output of 当我打印我的预测时,它给我输出
[[0. 1. 0.]]
I've read up about numpy's Argmax but I'm not entirely sure how to use it in this case. 我已经阅读了有关numpy的Argmax的信息,但是我不确定在这种情况下如何使用它。
I've tried using 我试过使用
print(np.argmax(prediction))
But that gives me the output of 1
. 但这给了我
1
的输出。 That's great but I would like to find out what's the index of 1
and then print out the Category instead of the highest value. 很好,但是我想找出
1
的索引是什么,然后打印出Category而不是最大值。
import cv2
import tensorflow as tf
import numpy as np
CATEGORIES = ["Tree", "Stump", "Ground"]
def prepare(filepath):
IMG_SIZE = 150 # This value must be the same as the value in Part1
img_array = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
new_array = cv2.resize(img_array, (IMG_SIZE, IMG_SIZE))
return new_array.reshape(-1, IMG_SIZE, IMG_SIZE, 1)
# Able to load a .model, .h3, .chibai and even .dog
model = tf.keras.models.load_model("models/test.model")
prediction = model.predict([prepare('image.jpg')])
print("Predictions:")
print(prediction)
print(np.argmax(prediction))
I expect my prediction to show me: 我希望我的预测能告诉我:
Predictions:
[[0. 1. 0.]]
Stump
Thanks for reading :) I appreciate any help at all. 感谢您的阅读:)非常感谢您的帮助。
You just have to index categories with the result of np.argmax
: 您只需要使用
np.argmax
的结果索引类别:
pred_name = CATEGORIES[np.argmax(prediction)]
print(pred_name)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.