簡體   English   中英

在numpy數組中查找最大值的索引

[英]Find index of the maximum value in a numpy array

我有一個名為predictions的 numpy 數組,如下所示

array([[3.7839172e-06, 8.0308418e-09, 2.2542761e-06, 5.9392878e-08,
        5.3137046e-07, 1.7033290e-05, 1.7738441e-07, 1.0742254e-03,
        1.8656212e-06, 9.9890006e-01]], dtype=float32)

為了獲得此數組中最大值的索引,我使用了以下內容

np.where(prediction==prediction.max())

但是我得到的結果也顯示索引 0。

(array([0], dtype=int64), array([9], dtype=int64))

有誰知道為什么它也顯示索引 0? 另外我怎樣才能得到索引號而不是顯示為 (array([9], dtype=int64)

為它使用內置函數:

prediction.argmax()

輸出:

9

此外,該索引0是行號,因此最大值位於第0行和第9列。

這里的predictions數組是二維的。 當您僅使用一個條件調用np.where時,這與調用np.asarray(condition).nonzero() ,后者返回prediction==prediction.max()的非零元素的索引,這是一個在(0,9)處具有唯一非零元素的布爾數組。

您正在尋找的是argmax函數,它將為您提供沿軸的最大值的索引。 您實際上只有一個軸(2d 但只有一行),所以這應該沒問題。

正如其他答案所提到的,您有一個二維數組,因此最終有兩個索引。 由於數組只是一行,因此第一個索引始終為零。 您可以通過多種方式繞過它:

  1. 使用prediction.argmax() 默認axis參數是None ,這意味着對扁平數組進行操作。 其他可以獲得相同結果的選項是prediction.argmax(-1) (最后一個軸)和prediction.argmax(1) (第二個軸)。 請記住,您只能通過這種方式獲得第一個最大值的索引。 如果您只希望擁有一個,或者只需要一個,那很好。

  2. 使用np.flatnonzero以類似於您正在執行的方式獲取線性索引:

     np.flatnonzero(perdiction == prediction.max())
  3. 使用np.nonzeronp.where ,但提取您關心的軸:

     np.nonzero(prediction == prediction.max())[1]
  4. 在輸入上ravel數組:

     np.where(prediction.ravel() == prediction.max())
  5. 做同樣的事情,但使用np.squeeze

     np.nonzero(prediction.squeeze() == prediction.max())

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM