
How does numpy.where work?

I can understand the following NumPy behavior:

>>> a
array([[ 0. ,  0. ,  0. ],
       [ 0. ,  0.7,  0. ],
       [ 0. ,  0.3,  0.5],
       [ 0.6,  0. ,  0.8],
       [ 0.7,  0. ,  0. ]])
>>> argmax_overlaps = a.argmax(axis=1)
>>> argmax_overlaps
array([0, 1, 2, 2, 0])
>>> max_overlaps = a[np.arange(5),argmax_overlaps]
>>> max_overlaps
array([ 0. ,  0.7,  0.5,  0.8,  0.7])
>>> gt_argmax_overlaps = a.argmax(axis=0)
>>> gt_argmax_overlaps
array([4, 1, 3])
>>> gt_max_overlaps = a[gt_argmax_overlaps,np.arange(a.shape[1])]
>>> gt_max_overlaps
array([ 0.7,  0.7,  0.8])
>>> gt_argmax_overlaps = np.where(a == gt_max_overlaps)
>>> gt_argmax_overlaps
(array([1, 3, 4]), array([1, 2, 0]))

I understand that the matches are a[1,1], a[3,2] and a[4,0] (holding 0.7, 0.8 and 0.7), so I got the tuple (array([1, 3, 4]), array([1, 2, 0])). The first array holds the 0th (row) indices of those three elements and the second holds their 1st (column) indices. I then tried other examples to check whether my understanding is correct.
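For reference, here is a minimal sketch of how the two returned arrays pair up. It assumes NumPy is imported as np and rebuilds a from the session above. The names rows, cols and pairs are only illustrative. Zipping the two index arrays gives (row, column) coordinates, indexing a with both arrays returns the matched values, and np.argwhere reports the same coordinates as one row per match.

```python
import numpy as np

# The array from the session above
a = np.array([[0.0, 0.0, 0.0],
              [0.0, 0.7, 0.0],
              [0.0, 0.3, 0.5],
              [0.6, 0.0, 0.8],
              [0.7, 0.0, 0.0]])

# Column-wise maxima, equivalent to gt_max_overlaps above: [0.7, 0.7, 0.8]
gt_max_overlaps = a.max(axis=0)

# np.where returns one index array per dimension: (row indices, column indices)
rows, cols = np.where(a == gt_max_overlaps)

# Pair the i-th row index with the i-th column index
pairs = list(zip(rows.tolist(), cols.tolist()))
print(pairs)              # [(1, 1), (3, 2), (4, 0)]

# Indexing with both arrays picks out the matched values
print(a[rows, cols])      # [0.7 0.8 0.7]

# np.argwhere gives the same coordinates, one (row, col) per line
print(np.argwhere(a == gt_max_overlaps))
```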

>>> np.where(a == [0.3])
(array([2]), array([1]))

0.3 is at a[2,1], so the outcome looks as I expected. Then I tried:

>>> np.where(a == [0.3, 0.5])
(array([], dtype=int64),)

I expected to see (array([2, 2]), array([1, 2])), since 0.3 is at a[2,1] and 0.5 is at a[2,2]. Why do I get the output above instead?

>>> np.where(a == [0.7, 0.7, 0.8])
(array([1, 3, 4]), array([1, 2, 0]))
>>> np.where(a == [0.8,0.7,0.7])
(array([1]), array([1]))

I can't understand the second result either. Could someone please explain it to me? Thanks.

The first thing to realize is that np.where(a == [whatever]) is just showing you the indices where a == [whatever] is True. So you can get a hint by looking at the value of a == [whatever] itself. In your case that "works":

>>> a == [0.7, 0.7, 0.8]
array([[False, False, False],
       [False,  True, False],
       [False, False, False],
       [False, False,  True],
       [ True, False, False]], dtype=bool)
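You can confirm directly that np.where is only reading off the True entries of this mask. With a single argument, np.where(cond) is documented as shorthand for np.asarray(cond).nonzero(), so both calls below should return the same index arrays. The sketch rebuilds a from the question, and mask, rows and cols are illustrative names.

```python
import numpy as np

# The array from the question
a = np.array([[0.0, 0.0, 0.0],
              [0.0, 0.7, 0.0],
              [0.0, 0.3, 0.5],
              [0.6, 0.0, 0.8],
              [0.7, 0.0, 0.0]])

# Boolean mask with the same shape as a
mask = a == [0.7, 0.7, 0.8]

# Single-argument np.where: one index array per dimension of the mask
rows, cols = np.where(mask)

# The documented equivalent
nz_rows, nz_cols = np.nonzero(mask)

print(rows.tolist(), cols.tolist())   # [1, 3, 4] [1, 2, 0]
print(np.array_equal(rows, nz_rows) and np.array_equal(cols, nz_cols))  # True
```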

You aren't getting what you think you are. You think this asks for the indices of each value separately, but instead it finds the positions where a value matches at the same position within the row. Basically, this comparison says "for each row, tell me whether the first element is 0.7, whether the second is 0.7, and whether the third is 0.8". np.where then returns the indices of those matching positions. In other words, the list is lined up against every row of a (NumPy calls this broadcasting), so each value in the list is compared only with the column at the same position, not with every element of the array. For your last example:

>>> a == [0.8,0.7,0.7]
array([[False, False, False],
       [False,  True, False],
       [False, False, False],
       [False, False, False],
       [False, False, False]], dtype=bool)

You now get a different result. It's not asking for "the indices where a has value 0.8". It's asking only for the positions where a row has 0.8 in its first column, 0.7 in its second column, or 0.7 in its third column. The only element that satisfies any of these is a[1,1].
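To make the contrast concrete, this small sketch compares a scalar search for 0.8 with the positional list comparison, and then rebuilds the positional mask one column at a time. It rebuilds a from the question, and the helper name positions is hypothetical.

```python
import numpy as np

a = np.array([[0.0, 0.0, 0.0],
              [0.0, 0.7, 0.0],
              [0.0, 0.3, 0.5],
              [0.6, 0.0, 0.8],
              [0.7, 0.0, 0.0]])

def positions(mask):
    """Return the (row, col) coordinates of the True entries as a plain list."""
    return np.argwhere(mask).tolist()

# Scalar comparison: where does 0.8 occur anywhere in a?
print(positions(a == 0.8))            # [[3, 2]]

# List comparison: column 0 vs 0.8, column 1 vs 0.7, column 2 vs 0.7
list_mask = a == [0.8, 0.7, 0.7]
print(positions(list_mask))           # [[1, 1]]

# The same mask built column by column
by_column = np.column_stack([a[:, 0] == 0.8,
                             a[:, 1] == 0.7,
                             a[:, 2] == 0.7])
print(np.array_equal(list_mask, by_column))  # True
```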

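This kind of row-wise comparison only works if the list you compare against has the same length as a row of a. A one-element list such as [0.3] also works, because it is stretched like a scalar across every element. With a two-element list, the shapes (5, 3) and (2,) cannot be broadcast together, so no element-wise comparison is possible. Older NumPy versions, like the one that produced your output, then fall back to comparing the whole list with the whole array as a single object. That gives one scalar False, and np.where on it returns a tuple holding a single empty index array. Newer NumPy versions raise a ValueError for this comparison instead.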
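The shape rule can be checked without touching the data. This sketch assumes NumPy 1.20 or later, which provides np.broadcast_shapes. That function reports the shape a comparison with a would produce, or raises ValueError when the shapes are incompatible. The sketch deliberately avoids running a == [0.3, 0.5] itself, since what that expression does depends on the NumPy version.

```python
import numpy as np

a_shape = (5, 3)  # shape of the array in the question

# A full row (3,) or a single value (1,) can be stretched over every row
for other in [(3,), (1,)]:
    print(other, "->", np.broadcast_shapes(a_shape, other))  # (5, 3) both times

# A two-element list cannot be lined up with rows of length 3
try:
    np.broadcast_shapes(a_shape, (2,))
except ValueError as exc:
    print("(2,) -> ValueError:", exc)
```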

The upshot is that you can't use == with a list of values and expect it to tell you where any of those values occurs. The comparison matches by value and position when the list has the same shape as a row of your array. When the shapes are incompatible, it either gives a single False (older NumPy) or raises an error (newer NumPy). If you want to search for each value independently, you need to do something like what Khris suggested in a comment:

np.where((a==0.3)|(a==0.5))

That is, you need to make two (or more) separate comparisons against separate values and combine them with |, rather than a single comparison against a list of values.
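For more than a couple of values, chaining | gets long. One alternative not mentioned in the thread is np.isin, which tests every element against a list of values and should return the same mask as the chained comparisons. Exact == on floats can miss values produced by arithmetic, so the sketch also shows a tolerance-based variant using np.isclose. The names values, membership and close are illustrative.

```python
import numpy as np

a = np.array([[0.0, 0.0, 0.0],
              [0.0, 0.7, 0.0],
              [0.0, 0.3, 0.5],
              [0.6, 0.0, 0.8],
              [0.7, 0.0, 0.0]])
values = [0.3, 0.5]

# Chained comparisons from the answer (parentheses matter: | binds tighter than ==)
chained = (a == 0.3) | (a == 0.5)

# np.isin builds the same mask for any number of values
membership = np.isin(a, values)
print(np.array_equal(chained, membership))       # True

rows, cols = np.where(membership)
print(list(zip(rows.tolist(), cols.tolist())))   # [(2, 1), (2, 2)]

# Tolerance-based variant: compare every element with every value
# (shape (5, 3, 2)), then keep elements that are close to any value
close = np.isclose(a[..., np.newaxis], values).any(axis=-1)
print(np.argwhere(close).tolist())               # [[2, 1], [2, 2]]
```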

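Notice: technical posts on this site are licensed under CC BY-SA 4.0. If you republish this page, please credit this site's URL or the original address. For any questions, contact yoyou2525@163.com.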

 