[英]Find index of the maximum value in a numpy array
我有一個名為predictions
的 numpy 數組,如下所示
array([[3.7839172e-06, 8.0308418e-09, 2.2542761e-06, 5.9392878e-08,
5.3137046e-07, 1.7033290e-05, 1.7738441e-07, 1.0742254e-03,
1.8656212e-06, 9.9890006e-01]], dtype=float32)
為了獲得此數組中最大值的索引,我使用了以下內容
np.where(prediction==prediction.max())
但是我得到的結果也顯示索引 0。
(array([0], dtype=int64), array([9], dtype=int64))
有誰知道為什么它也顯示索引 0? 另外我怎樣才能得到索引號而不是顯示為 (array([9], dtype=int64)
為它使用內置函數:
prediction.argmax()
輸出:
9
此外,該索引0
是行號,因此最大值位於第0
行和第9
列。
這里的predictions
數組是二維的。 當您僅使用一個條件調用np.where
時,這與調用np.asarray(condition).nonzero()
,后者返回prediction==prediction.max()
的非零元素的索引,這是一個在(0,9)
處具有唯一非零元素的布爾數組。
您正在尋找的是argmax
函數,它將為您提供沿軸的最大值的索引。 您實際上只有一個軸(2d 但只有一行),所以這應該沒問題。
正如其他答案所提到的,您有一個二維數組,因此最終有兩個索引。 由於數組只是一行,因此第一個索引始終為零。 您可以通過多種方式繞過它:
使用prediction.argmax()
。 默認axis
參數是None
,這意味着對扁平數組進行操作。 其他可以獲得相同結果的選項是prediction.argmax(-1)
(最后一個軸)和prediction.argmax(1)
(第二個軸)。 請記住,您只能通過這種方式獲得第一個最大值的索引。 如果您只希望擁有一個,或者只需要一個,那很好。
使用np.flatnonzero
以類似於您正在執行的方式獲取線性索引:
np.flatnonzero(perdiction == prediction.max())
使用np.nonzero
或np.where
,但提取您關心的軸:
np.nonzero(prediction == prediction.max())[1]
在輸入上ravel
數組:
np.where(prediction.ravel() == prediction.max())
做同樣的事情,但使用np.squeeze
:
np.nonzero(prediction.squeeze() == prediction.max())
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.