
Find index of the maximum value in a numpy array

I have a NumPy array called prediction, as follows:

array([[3.7839172e-06, 8.0308418e-09, 2.2542761e-06, 5.9392878e-08,
        5.3137046e-07, 1.7033290e-05, 1.7738441e-07, 1.0742254e-03,
        1.8656212e-06, 9.9890006e-01]], dtype=float32)

To get the index of the maximum value in this array, I used the following:

np.where(prediction==prediction.max())

But the result I get also shows index 0:

(array([0], dtype=int64), array([9], dtype=int64))

Does anyone know why it also shows index 0? Also, how can I get just the index number instead of array([9], dtype=int64)?

Use the built-in argmax method for it:

prediction.argmax()

output:

9

Also, that index 0 is the row number, so the maximum is at row 0, column 9.
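
If you want a plain Python integer, or need the (row, column) pair rather than the flat position, a minimal sketch is to wrap argmax in int() and convert back with np.unravel_index. This combination is an illustration, not part of the original answer; the array is the one from the question.

```python
import numpy as np

# The array from the question: shape (1, 10), one row of ten scores
prediction = np.array([[3.7839172e-06, 8.0308418e-09, 2.2542761e-06, 5.9392878e-08,
                        5.3137046e-07, 1.7033290e-05, 1.7738441e-07, 1.0742254e-03,
                        1.8656212e-06, 9.9890006e-01]], dtype=np.float32)

# argmax() with no axis works on the flattened array; int() gives a plain Python int
flat_index = int(prediction.argmax())

# unravel_index maps the flat position back to (row, column) for this shape
row, col = np.unravel_index(flat_index, prediction.shape)

print(flat_index)  # 9
print(row, col)    # 0 9
```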

The prediction array here is two-dimensional. When you call np.where with only a condition, it is the same as calling np.asarray(condition).nonzero(), which returns the indices of the non-zero elements of prediction == prediction.max(). That is a boolean array whose only non-zero (True) element is at (0, 9), so you get one index array per dimension: [0] for the row and [9] for the column.
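
To see the equivalence described above, here is a short sketch that builds the boolean mask from the question's array and compares np.where(mask) with np.asarray(mask).nonzero(). The variable names are illustrative.

```python
import numpy as np

# The 2D array from the question (1 row, 10 columns)
prediction = np.array([[3.7839172e-06, 8.0308418e-09, 2.2542761e-06, 5.9392878e-08,
                        5.3137046e-07, 1.7033290e-05, 1.7738441e-07, 1.0742254e-03,
                        1.8656212e-06, 9.9890006e-01]], dtype=np.float32)

# Boolean mask of the same shape; only position (0, 9) is True
mask = prediction == prediction.max()

# Called with just a condition, np.where returns the same tuple as nonzero():
# one index array per dimension, i.e. (row indices, column indices)
where_result = np.where(mask)
nonzero_result = np.asarray(mask).nonzero()
assert all(np.array_equal(a, b) for a, b in zip(where_result, nonzero_result))

rows, cols = where_result
print(rows.tolist(), cols.tolist())  # [0] [9]
```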

What you are looking for is the argmax function, which gives you the index of the maximum value along an axis. You effectively have only one axis here (the array is 2D but has only one row), so this should be fine.
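
The axis starts to matter once there is more than one row. A small sketch with a made-up batch of three rows (the values are hypothetical, not from the question) shows the difference between argmax along the last axis and argmax on the flattened array:

```python
import numpy as np

# Hypothetical batch: 3 rows (e.g. three predictions) of 4 scores each
batch = np.array([[0.1, 0.7, 0.1, 0.1],
                  [0.6, 0.2, 0.1, 0.1],
                  [0.0, 0.1, 0.0, 0.9]])

# axis=1 (the last axis here) gives one column index per row
per_row = batch.argmax(axis=1)
print(per_row.tolist())  # [1, 0, 3]

# With no axis, argmax flattens first and returns one position in the flat array
print(batch.argmax())  # 11
```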

As the other answers mentioned, you have a 2D array, so you end up with two indices. Since the array is just a row, the first index is always zero. You can bypass this in a number of ways:

  1. Use prediction.argmax(). The default axis argument is None, which means it operates on the flattened array. Other options that find the same index are prediction.argmax(-1) (last axis) and prediction.argmax(1) (second axis), although these return a one-element array (one entry per row) rather than a single integer. Keep in mind that you will only ever get the index of the first maximum this way. That's fine if you only ever expect to have one, or only need one.

  2. Use np.flatnonzero to get the linear indices similarly to the way you were doing:

     np.flatnonzero(prediction == prediction.max())
  3. Use np.nonzero or np.where, but extract the axis you care about:

     np.nonzero(prediction == prediction.max())[1]
  4. Ravel the array on input:

     np.where(prediction.ravel() == prediction.max())
  5. Do the same thing, but with np.squeeze:

     np.nonzero(prediction.squeeze() == prediction.max())
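
A side-by-side sketch of the five options on the question's array, plus a hypothetical tied array, may help compare their return types. Note that np.where and np.nonzero still return a tuple even for a 1D input, so options 4 and 5 need [0] to get at the index array itself.

```python
import numpy as np

# The array from the question
prediction = np.array([[3.7839172e-06, 8.0308418e-09, 2.2542761e-06, 5.9392878e-08,
                        5.3137046e-07, 1.7033290e-05, 1.7738441e-07, 1.0742254e-03,
                        1.8656212e-06, 9.9890006e-01]], dtype=np.float32)
m = prediction.max()

# 1. argmax: a single index into the flattened array
idx_argmax = prediction.argmax()

# 2. flatnonzero: a 1D array of every flat index equal to the max
idx_flat = np.flatnonzero(prediction == m)

# 3. nonzero on the 2D mask, keeping only the column indices
idx_cols = np.nonzero(prediction == m)[1]

# 4. and 5. ravel / squeeze make the mask 1D, but np.where / np.nonzero
#    still return a tuple, so take element [0] to get the index array
idx_ravel = np.where(prediction.ravel() == m)[0]
idx_squeeze = np.nonzero(prediction.squeeze() == m)[0]

print(idx_argmax, idx_flat, idx_cols, idx_ravel, idx_squeeze)
# 9 [9] [9] [9] [9]

# Hypothetical tie: argmax reports only the first maximum, flatnonzero reports all
tied = np.array([[0.5, 0.1, 0.5]])
print(tied.argmax(), np.flatnonzero(tied == tied.max()))
# 0 [0 2]
```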

The technical post webpages of this site follow the CC BY-SA 4.0 license. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
Guangdong ICP Filing No. 18138465  © 2020-2024 STACKOOM.COM