[英]Filtering rows of numpy array based on whether row elements are in another array
我有一個數組group
,它是 Nx2:
array([[ 1, 6],
[ 1, 0],
[ 2, 1],
...,
[40196, 40197],
[40196, 40198],
[40196, 40199]], dtype=uint32)
另一個數組selection
是 (M,):
array([3216, 3217, 3218, ..., 8039])
我想創建一個新數組,其中包含兩個元素都在selection
的group
所有行。 我是這樣做的:
np.array([(i,j) for (i,j) in group if i in selection and j in selection])
這有效,但我知道必須有一種更有效的方法來利用一些 numpy 函數。
您可以使用np.isin
獲取與group
形狀相同的布爾數組,該數組表示元素是否在selection
。 然后,要檢查 rows 中的兩個條目是否都在selection
,您可以將all
與axis=1
,這將給出一個一維布爾數組,說明要保留哪些行。 我們最終用它索引:
group[np.isin(group, selection).all(axis=1)]
樣本:
>>> group
array([[ 1, 6],
[ 1, 0],
[ 2, 1],
[40196, 40197],
[40196, 40198],
[40196, 40199]])
>>> selection
array([ 1, 2, 3, 4, 5, 6, 40196, 40199])
>>> np.isin(group, selection)
array([[ True, True],
[ True, False],
[ True, True],
[ True, False],
[ True, False],
[ True, True]])
>>> np.isin(group, selection).all(axis=1)
array([ True, False, True, False, False, True])
>>> group[np.isin(group, selection).all(axis=1)]
array([[ 1, 6],
[ 2, 1],
[40196, 40199]])
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.