繁体   English   中英

如何处理numpy.where条件不满意的情况?

[英]How to handle scenario where numpy.where condition is unsatisfied?

我正在转换此数组:

x = np.array([[0, 0, 1], [1, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 0]])

到: [2, 0, 1, 0, 0]

基本上,我想返回每个子数组中前1的索引。 但是,我的问题是我不知道如何处理没有1 我希望它在没有找到1情况下返回0 (例如在我的示例中)。

下面的代码工作正常,但抛出IndexError: index 0 is out of bounds for axis 0 with size 0对于我提到的场景, IndexError: index 0 is out of bounds for axis 0 with size 0

np.array([np.where(r == 1)[0][0] for r in x])

有什么简单的方法可以解决这个问题? 不必限于numpy.where。

我正在使用Python 3。

使用1s mask ,然后在每行使用argmax以获取第一个匹配索引以及any索引以检查有效行(行中至少包含1 )-

mask = x==1
idx = np.where(mask.any(1), mask.argmax(1),0)

现在,所有False上的argmax将返回0 因此,这完全可以解决上述问题。 这样,我们可以简单地使用mask.argmax(1)结果。 但在一般情况下,如果无效说明符称为invalid_val不为0 ,则可以在np.where指定,例如-

idx = np.where(mask.any(1), mask.argmax(1),invalid_val)

另一种方法是获取掩码上的第一个匹配索引,然后索引到掩码中以查看任何索引值是否为False并将其设置为0s

idx = mask.argmax(1)
idx[~mask[np.arange(len(idx)), idx]] = 0 # or invalid_val

对您的代码的一个简单修改就是向列表理解中添加一个条件:

np.array([np.where(r == 1)[0][0] if 1 in r else 0 for r in x])
# 23.1 µs ± 43.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

获得相同结果的更简洁,实质上更快的方法是:

np.argmax(x == 1, axis=1)
# 4.04 µs ± 45.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

或等效地:

np.argmin(x != 1, axis=1)
# 4.03 µs ± 13.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM