
How to handle the scenario where the numpy.where condition is unsatisfied?

I am converting this array:

x = np.array([[0, 0, 1], [1, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 0]])

to: [2, 0, 1, 0, 0].

Basically, I want to return the index of the first 1 in each sub-array. However, my problem is that I don't know how to handle the scenario where there is no 1. I want it to return 0 if 1 is not found (as in the last row of my example).

The code below works fine for rows that contain a 1, but throws IndexError: index 0 is out of bounds for axis 0 with size 0 for the scenario I mentioned:

np.array([np.where(r == 1)[0][0] for r in x])
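To make the failure concrete, here is a minimal sketch of what happens row by row. The names hits, sizes, and error_message are illustrative only and not part of the original code.

```python
import numpy as np

x = np.array([[0, 0, 1], [1, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 0]])

# np.where(r == 1)[0] holds the positions of the 1s in row r.
hits = [np.where(r == 1)[0] for r in x]

# Number of matches per row; the last row has none.
sizes = [h.size for h in hits]
print(sizes)

# Taking element [0] of an empty array is what raises the IndexError.
error_message = None
try:
    hits[-1][0]
except IndexError as exc:
    error_message = str(exc)
print(error_message)
```

The last row yields an empty index array, so there is no element 0 to take.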

What is an easy way to handle this? It does not need to be restricted to numpy.where.

I am using Python 3, by the way.

Use a mask of 1s, then argmax along each row to get the first matching index, along with any to check for valid rows (rows with at least one 1):

mask = x == 1
idx = np.where(mask.any(1), mask.argmax(1), 0)
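As a sketch of how the pieces fit together, the snippet below splits the one-liner into its intermediate arrays on the array from the question. The names has_one and first_idx are introduced here for illustration.

```python
import numpy as np

x = np.array([[0, 0, 1], [1, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 0]])

mask = x == 1                     # boolean array, True where x equals 1
has_one = mask.any(axis=1)        # True for rows containing at least one 1
first_idx = mask.argmax(axis=1)   # position of the first True in each row

# Keep the argmax result for valid rows, fall back to 0 otherwise.
idx = np.where(has_one, first_idx, 0)
print(has_one)
print(idx)
```

Only the last row has has_one equal to False, so it is the only row where the fallback value is used.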

Now, argmax on a row that is all False returns 0, which plays right into the hands of the stated problem. As such, we can simply use the mask.argmax(1) result. But in the general case, where the value to use for invalid rows (let's call it invalid_val) is not 0, we can specify it inside np.where, like so:

idx = np.where(mask.any(1), mask.argmax(1), invalid_val)
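For instance, assuming invalid_val = -1 (a value chosen here purely for illustration), rows without a 1 become distinguishable from rows whose first 1 sits at index 0:

```python
import numpy as np

x = np.array([[0, 0, 1], [1, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 0]])

invalid_val = -1  # assumed sentinel for rows that contain no 1

mask = x == 1
idx = np.where(mask.any(axis=1), mask.argmax(axis=1), invalid_val)
print(idx)
```

With a fallback of 0, rows two, four, and five would all report 0 even though only the first two of those actually have a 1 at position 0.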

Another method would be to get the first matching index on the mask, then index into the mask with it to see which of the indexed values are False, and set those entries to 0:

idx = mask.argmax(1)
idx[~mask[np.arange(len(idx)), idx]] = 0 # or invalid_val
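A self-contained sketch of this second method, again assuming a sentinel of -1 so that the effect of the assignment is visible. The name picked is illustrative.

```python
import numpy as np

x = np.array([[0, 0, 1], [1, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 0]])

invalid_val = -1  # assumed sentinel for rows that contain no 1

mask = x == 1
idx = mask.argmax(axis=1)

# Look up the mask value at the argmax position of every row.
# It is False only when the row has no True at all.
picked = mask[np.arange(len(idx)), idx]

# Overwrite the index for those rows.
idx[~picked] = invalid_val
print(idx)
```

This avoids a separate any pass over the full mask, at the cost of one fancy-indexing lookup per row.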

A simple modification to your code would be to add a condition to the list comprehension:

np.array([np.where(r == 1)[0][0] if 1 in r else 0 for r in x])
# 23.1 µs ± 43.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

A more concise and substantially faster way of obtaining the same result is:

np.argmax(x == 1, axis=1)
# 4.04 µs ± 45.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

or, equivalently:

np.argmin(x != 1, axis=1)
# 4.03 µs ± 13.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
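As a quick sanity check, the following sketch confirms that the three variants agree on the array from the question. The variable names are illustrative, and the timings above are not reproduced here.

```python
import numpy as np

x = np.array([[0, 0, 1], [1, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 0]])

# List comprehension with an explicit guard for rows without a 1.
via_loop = np.array([np.where(r == 1)[0][0] if 1 in r else 0 for r in x])

# argmax returns the first True; an all-False row gives 0.
via_argmax = np.argmax(x == 1, axis=1)

# argmin returns the first False of the negated mask; an all-True row gives 0.
via_argmin = np.argmin(x != 1, axis=1)

all_agree = np.array_equal(via_loop, via_argmax) and np.array_equal(via_argmax, via_argmin)
print(via_argmax, all_agree)
```

Note that the argmax and argmin shortcuts rely on the desired fallback being 0; for any other fallback value, a row-validity check such as mask.any(1) is still needed.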
