[英]How to handle scenario where numpy.where condition is unsatisfied?
我正在转换此数组:
x = np.array([[0, 0, 1], [1, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 0]])
到: [2, 0, 1, 0, 0]
。
基本上,我想返回每个子数组中前1
的索引。 但是,我的问题是我不知道如何处理没有1
。 我希望它在没有找到1
情况下返回0
(例如在我的示例中)。
下面的代码工作正常,但抛出IndexError: index 0 is out of bounds for axis 0 with size 0
对于我提到的场景, IndexError: index 0 is out of bounds for axis 0 with size 0
:
np.array([np.where(r == 1)[0][0] for r in x])
有什么简单的方法可以解决这个问题? 不必限于numpy.where。
我正在使用Python 3。
使用1s
mask
,然后在每行使用argmax
以获取第一个匹配索引以及any
索引以检查有效行(行中至少包含1
)-
mask = x==1
idx = np.where(mask.any(1), mask.argmax(1),0)
现在,所有False
上的argmax
将返回0
。 因此,这完全可以解决上述问题。 这样,我们可以简单地使用mask.argmax(1)
结果。 但在一般情况下,如果无效说明符称为invalid_val
不为0
,则可以在np.where
指定,例如-
idx = np.where(mask.any(1), mask.argmax(1),invalid_val)
另一种方法是获取掩码上的第一个匹配索引,然后索引到掩码中以查看任何索引值是否为False
并将其设置为0s
idx = mask.argmax(1)
idx[~mask[np.arange(len(idx)), idx]] = 0 # or invalid_val
对您的代码的一个简单修改就是向列表理解中添加一个条件:
np.array([np.where(r == 1)[0][0] if 1 in r else 0 for r in x])
# 23.1 µs ± 43.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
获得相同结果的更简洁,实质上更快的方法是:
np.argmax(x == 1, axis=1)
# 4.04 µs ± 45.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
或等效地:
np.argmin(x != 1, axis=1)
# 4.03 µs ± 13.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.