![](/img/trans.png)
[英]What is the fastest way to extract given rows and columns from a Numpy ndarray?
[英]Extract numpy rows by given condition
我有如下的 numpy 数组。
import numpy as np
data = np.array([[0,0,0,4],
[3,0,5,0],
[8,9,5,3]])
print (data)
我必须只提取前三个元素不全为零的那些行,预期结果如下:
result = np.array([[3,0,5,0],
[8,9,5,3]])
我试过:
res = [l for l in data if l[:3].sum() !=0]
print (res)
它给出了结果。 但是,正在寻找更好的、麻木的方法。
如果您的数组可以包含负数,则sum
有点不可靠,但any
将始终有效:
result = data[data[:, :3].any(1)]
你说
前三个元素不全为零
所以一个解决方案是
import numpy as np
data = np.array([[0,0,0,4],
[3,0,5,0],
[8,9,5,3]])
data[~np.all(data[:, :3] == 0, axis=1), :]
我将尝试通过我的回答来解释我是如何看待这些问题的。
第一步:定义一个函数,该函数返回一个布尔值,指示这是否是一个好行。 为此,我使用 np.any,它检查是否有任何条目为“真”(对于整数,真为非零)。
import numpy as np
v1 = np.array([1, 1, 1, 0])
v2 = np.array([0, 0, 0, 1])
good_row = lambda v: np.any(v[:3])
good_row(v1)
Out[28]: True
good_row(v2)
Out[29]: False
第二步:我将其应用于所有行,并获得一个掩蔽向量。 为此,可以在“np.any”中使用“axis”关键字,它将根据轴值将其应用于列或行。
np.any(data[:, :3], axis=1)
Out[32]: array([False, True, True])
最后一步:我将其与索引结合起来,将其全部包装起来。
rows_inds = np.any(data[:, :3], axis=1)
data[rows_inds]
Out[37]:
array([[3, 0, 5, 0],
[8, 9, 5, 3]])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.