[英]Filtering rows of numpy array based on whether row elements are in another array
我有一个数组group
,它是 Nx2:
array([[ 1, 6],
[ 1, 0],
[ 2, 1],
...,
[40196, 40197],
[40196, 40198],
[40196, 40199]], dtype=uint32)
另一个数组selection
是 (M,):
array([3216, 3217, 3218, ..., 8039])
我想创建一个新数组,其中包含两个元素都在selection
的group
所有行。 我是这样做的:
np.array([(i,j) for (i,j) in group if i in selection and j in selection])
这有效,但我知道必须有一种更有效的方法来利用一些 numpy 函数。
您可以使用np.isin
获取与group
形状相同的布尔数组,该数组表示元素是否在selection
。 然后,要检查 rows 中的两个条目是否都在selection
,您可以将all
与axis=1
,这将给出一个一维布尔数组,说明要保留哪些行。 我们最终用它索引:
group[np.isin(group, selection).all(axis=1)]
样本:
>>> group
array([[ 1, 6],
[ 1, 0],
[ 2, 1],
[40196, 40197],
[40196, 40198],
[40196, 40199]])
>>> selection
array([ 1, 2, 3, 4, 5, 6, 40196, 40199])
>>> np.isin(group, selection)
array([[ True, True],
[ True, False],
[ True, True],
[ True, False],
[ True, False],
[ True, True]])
>>> np.isin(group, selection).all(axis=1)
array([ True, False, True, False, False, True])
>>> group[np.isin(group, selection).all(axis=1)]
array([[ 1, 6],
[ 2, 1],
[40196, 40199]])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.