Why do we use numpy.argmax() to return an index from a numpy array of predictions?
Let me preface this by saying that I am very new to neural networks, and this is my first time using numpy, tensorflow, or keras.
I wrote a neural network to recognize handwritten digits, using the MNIST data set. I followed this tutorial by Sentdex and noticed he was using
print(np.argmax(predictions[0]))
to print the first index from the numpy array of predictions.
I tried running the program with that line replaced by
print(predictions[i])
(with i set to 0), but the output was not a single number. It was:
[2.1975785e-08 1.8658861e-08 2.8842608e-06 5.7113186e-05 1.2067199e-10 7.2511304e-09 1.6282028e-12 9.9993789e-01 1.3356166e-08 2.0409643e-06]
The code that I'm confused about is:
predictions = model.predict(x_test)
for i in range(10):
    plt.imshow(x_test[i])
    plt.show()
    print("PREDICTION: ", predictions[i])
I read the numpy documentation for the argmax() function, and from what I understand, it takes in an x-dimensional array, converts it to a one-dimensional array, then returns the index of the largest value. The Keras documentation for model.predict() indicated that the function returns a numpy array of the network's predictions. So I don't understand why we have to use argmax() to properly print the prediction, because as I understand it, it has a completely unrelated purpose.
Sorry for the bad code formatting, I couldn't figure out how to properly insert multi-line chunks of code into my post.
If I understand your question correctly, the answer is pretty simple:
I hope I'm clear, haha.
What a classification neural network outputs (typically through a final softmax layer) is a probability distribution over the class indices, meaning that the network assigns one probability to each class. The sum of these probabilities is 1.0. The network is trained to assign the highest probability to the correct class, so to recover the class index from the probabilities you have to take the location (index) that has the maximum probability. This is done with the argmax operation.
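To make this concrete, here is a minimal sketch that applies argmax to the exact row printed in the question. The variable names (probs, batch, and so on) are made up for illustration, and the two-row batch is a hypothetical stand-in for the full array returned by model.predict(), not real model output.

```python
import numpy as np

# The row printed in the question: one probability per digit class 0-9.
probs = np.array([2.1975785e-08, 1.8658861e-08, 2.8842608e-06, 5.7113186e-05,
                  1.2067199e-10, 7.2511304e-09, 1.6282028e-12, 9.9993789e-01,
                  1.3356166e-08, 2.0409643e-06])

# The probabilities add up to (approximately) 1.0.
total = probs.sum()

# The largest value, 9.9993789e-01, sits at index 7, so the predicted digit is 7.
predicted_digit = int(np.argmax(probs))
print("PREDICTION:", predicted_digit)

# Hypothetical batch of two rows, standing in for the output of model.predict().
# np.roll shifts the peak from index 7 to index 2 in the second row.
batch = np.array([probs, np.roll(probs, -5)])

# axis=1 takes the argmax within each row, giving one class index per image.
batch_digits = np.argmax(batch, axis=1)
print(batch_digits)
```

Without an axis argument, np.argmax works on the flattened array, which is why the tutorial indexes a single row (predictions[0]) first. Passing axis=1 is one way to get a prediction for every image in a single call.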