How does numpy.where work?

I understand the following NumPy behavior.

>>> a
array([[ 0. ,  0. ,  0. ],
       [ 0. ,  0.7,  0. ],
       [ 0. ,  0.3,  0.5],
       [ 0.6,  0. ,  0.8],
       [ 0.7,  0. ,  0. ]])
>>> argmax_overlaps = a.argmax(axis=1)
>>> argmax_overlaps
array([0, 1, 2, 2, 0])
>>> max_overlaps = a[np.arange(5),argmax_overlaps]
>>> max_overlaps
array([ 0. ,  0.7,  0.5,  0.8,  0.7])
>>> gt_argmax_overlaps = a.argmax(axis=0)
>>> gt_argmax_overlaps
array([4, 1, 3])
>>> gt_max_overlaps = a[gt_argmax_overlaps,np.arange(a.shape[1])]
>>> gt_max_overlaps
array([ 0.7,  0.7,  0.8])
>>> gt_argmax_overlaps = np.where(a == gt_max_overlaps)
>>> gt_argmax_overlaps
(array([1, 3, 4]), array([1, 2, 0]))

I understand that 0.7, 0.7 and 0.8 are a[1,1], a[3,2] and a[4,0], so I get the tuple (array([1, 3, 4]), array([1, 2, 0])), whose two arrays hold the 0th and 1st indices of those three elements. I then tried other examples to check whether my understanding is correct.

>>> np.where(a == [0.3])
(array([2]), array([1]))

0.3 is at a[2,1], so the outcome is what I expected. Then I tried

>>> np.where(a == [0.3, 0.5])
(array([], dtype=int64),)

I expected to see (array([2, 2]), array([1, 2])), since 0.3 is at a[2,1] and 0.5 is at a[2,2]. Why do I get the output above instead?

>>> np.where(a == [0.7, 0.7, 0.8])
(array([1, 3, 4]), array([1, 2, 0]))
>>> np.where(a == [0.8,0.7,0.7])
(array([1]), array([1]))

I can't understand the second result either. Could someone please explain it to me? Thanks.

The first thing to realize is that np.where(a == [whatever]) just shows you the indices where a == [whatever] is True. So you can get a hint by looking at the value of a == [whatever] itself. For the example that seemed to "work":

>>> a == [0.7, 0.7, 0.8]
array([[False, False, False],
       [False,  True, False],
       [False, False, False],
       [False, False,  True],
       [ True, False, False]], dtype=bool)

You aren't getting what you think you are. You think it is looking up the indices of each value separately, but instead it finds the positions where an element matches the list entry at the same position in the row. NumPy broadcasts the three-element list against every row of a, so column j of a is compared with the j-th element of the list. Basically the comparison is saying "for each row, tell me whether the first element is 0.7, whether the second is 0.7, and whether the third is 0.8". np.where then returns the indices of those matching positions. In other words, each column is paired with one list entry; the values are not searched for everywhere in the array. For your last example:
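To make the pairing of columns and list entries concrete, here is a small sketch that spells out the comparison with an explicit loop over columns and checks that it gives the same mask as ==. The helper columnwise_equal is hypothetical, written only for illustration; np.broadcast_to is used to show the list "stretched" to a's shape, one copy per row.

```python
import numpy as np

# The 5x3 array from the question.
a = np.array([[0.0, 0.0, 0.0],
              [0.0, 0.7, 0.0],
              [0.0, 0.3, 0.5],
              [0.6, 0.0, 0.8],
              [0.7, 0.0, 0.0]])

def columnwise_equal(arr, row_values):
    """Hypothetical helper that spells out what `arr == row_values` does when
    row_values has one entry per column: column j is compared with row_values[j]."""
    out = np.zeros(arr.shape, dtype=bool)
    for j, value in enumerate(row_values):
        out[:, j] = arr[:, j] == value
    return out

for values in ([0.7, 0.7, 0.8], [0.8, 0.7, 0.7]):
    # Broadcasting repeats the 3-element list for every one of the 5 rows.
    stretched = np.broadcast_to(np.asarray(values), a.shape)
    assert np.array_equal(a == values, columnwise_equal(a, values))
    assert np.array_equal(a == values, a == stretched)
    print(values, [idx.tolist() for idx in np.where(a == values)])

# [0.7, 0.7, 0.8] [[1, 3, 4], [1, 2, 0]]
# [0.8, 0.7, 0.7] [[1], [1]]
```

The loop makes it visible that no value is ever compared with a column other than its own.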

>>> a == [0.8,0.7,0.7]
array([[False, False, False],
       [False,  True, False],
       [False, False, False],
       [False, False, False],
       [False, False, False]], dtype=bool)

You now get a different result. It's not asking for "the indices where a has value 0.8"; it's asking for the positions where column 0 is 0.8, column 1 is 0.7, or column 2 is 0.7. Only a[1,1] satisfies that, which is why you get (array([1]), array([1])).

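This kind of row-wise comparison only works if the list you compare against can be broadcast against a: it must have the same length as a row of a (3 here) or length 1, as in your [0.3] example, which acts just like the scalar 0.3. A two-element list cannot be broadcast against a (5, 3) array. Older NumPy versions give up on the elementwise comparison and return the single value False (with a warning), and np.where(False) yields the empty result you saw; newer versions raise a ValueError instead.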
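The broadcasting rule can be checked on shapes alone, without relying on how a given NumPy version handles a failed ==. This sketch assumes NumPy 1.20 or later, where np.broadcast_shapes exists and raises ValueError for incompatible shapes; the wrapper broadcasts is a hypothetical helper added for illustration.

```python
import numpy as np

def broadcasts(shape_a, shape_b):
    """Hypothetical helper: True if two array shapes are broadcast-compatible.
    Relies on np.broadcast_shapes (NumPy >= 1.20), which raises ValueError otherwise."""
    try:
        np.broadcast_shapes(shape_a, shape_b)
    except ValueError:
        return False
    return True

a_shape = (5, 3)  # shape of the question's array

print(broadcasts(a_shape, (3,)))  # True: one value per column, e.g. [0.7, 0.7, 0.8]
print(broadcasts(a_shape, (1,)))  # True: one value for every element, e.g. [0.3]
print(broadcasts(a_shape, (2,)))  # False: e.g. [0.3, 0.5]

# A length-1 list compares exactly like the scalar it holds.
a = np.array([[0.0, 0.0, 0.0],
              [0.0, 0.7, 0.0],
              [0.0, 0.3, 0.5],
              [0.6, 0.0, 0.8],
              [0.7, 0.0, 0.0]])
assert np.array_equal(a == [0.3], a == 0.3)
```

Shapes are matched from the trailing dimension, so a 1-D list is only compared against the last axis (the columns) of a.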

The upshot is that you can't use == with a list of values and expect it to tell you where any of the values occurs. The equality matches by value and position when the list can be broadcast against a (one entry per column, or a single entry), and it fails otherwise. If you want to search for the values independently, you need to do something like what Khris suggested in a comment:

np.where((a==0.3)|(a==0.5))

That is, you need to make two (or more) separate comparisons against separate values, not a single comparison against a list of values.
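For longer lists of values, the same "one comparison per value" idea can be generalized. The sketch below is an illustration, not the commenter's code: np.logical_or.reduce ORs together any number of comparisons, and np.isin (a swapped-in alternative not mentioned above) tests each element for membership in a list using exact equality. Because exact float equality can miss values produced by arithmetic, a tolerance-based variant using np.isclose is also shown as an assumption about what you might need.

```python
import numpy as np

# The 5x3 array from the question.
a = np.array([[0.0, 0.0, 0.0],
              [0.0, 0.7, 0.0],
              [0.0, 0.3, 0.5],
              [0.6, 0.0, 0.8],
              [0.7, 0.0, 0.0]])
targets = [0.3, 0.5]

# The comment's approach: one comparison per value, combined with element-wise OR.
or_chain = np.where((a == 0.3) | (a == 0.5))

# The same idea for an arbitrary list of values.
reduced = np.where(np.logical_or.reduce([a == v for v in targets]))

# np.isin checks each element of a for membership in targets (exact equality).
isin_result = np.where(np.isin(a, targets))

# Tolerant variant: add a trailing axis so every element is compared with every
# target (shape (5, 3, 2)), then ask whether any target was close enough.
close_result = np.where(np.isclose(a[..., np.newaxis], targets).any(axis=-1))

for result in (or_chain, reduced, isin_result, close_result):
    print([idx.tolist() for idx in result])  # [[2, 2], [1, 2]]
```

All four give the answer the question originally expected: 0.3 at a[2,1] and 0.5 at a[2,2].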
