[英]How to find row of 2d array in 3d numpy array
我試圖找到一個二維數組出現在3d numpy ndarray中的行。 這是我的意思的一個例子。 給:
arr = [[[0, 3], [3, 0]],
[[0, 0], [0, 0]],
[[3, 3], [3, 3]],
[[0, 3], [3, 0]]]
我想找到所有出現的:
[[0, 3], [3, 0]]
我想要的結果是:
[0, 3]
我試圖使用argwhere
但不幸的是讓我無處可去。 有任何想法嗎?
嘗試
np.argwhere(np.all(arr==[[0,3], [3,0]], axis=(1,2)))
這個怎么運作:
arr == [[0,3], [3,0]]
返回
array([[[ True, True],
[ True, True]],
[[ True, False],
[False, True]],
[[False, True],
[ True, False]],
[[ True, True],
[ True, True]]], dtype=bool)
這是一個三維數組,其中最內軸是2.此軸的值為:
[True, True]
[True, True]
[True, False]
[False, True]
[False, True]
[True, False]
[True, True]
[True, True]
現在使用np.all(arr==[[0,3], [3,0]], axis=2)
您將檢查行上的兩個元素是否為True
,其形狀是否將減少為(4,2)來自(4,2,2)。 像這樣:
array([[ True, True],
[False, False],
[False, False],
[ True, True]], dtype=bool)
您需要再減少一步,因為您希望它們兩者相同( [0, 3]
和[3, 0]
。您可以通過減少結果(現在最內軸為1)來實現:
np.all(np.all(test, axis = 2), axis=1)
或者您也可以通過為軸參數提供元組來逐步執行相同的操作(首先是最里面,然后再高一步)。 結果將是:
array([ True, False, False, True], dtype=bool)
numpy_indexed包中的'contains'函數(免責聲明:我是它的作者)可用於進行此類查詢。 它實現了類似於Saullo提供的解決方案。
import numpy_indexed as npi
test = [[[0, 3], [3, 0]]]
# check which elements of arr are present in test (checked along axis=0 by default)
flags = npi.contains(test, arr)
# if you want the indexes:
idx = np.flatnonzero(flags)
在定義一個新數據類型后,可以使用np.in1d
,該數據類型將包含arr
中每行的內存大小。 要定義此類數據類型:
mydtype = np.dtype((np.void, arr.dtype.itemsize*arr.shape[1]*arr.shape[2]))
那么你必須將你的arr
轉換為一維數組,其中每一行都有arr.shape[1]*arr.shape[2]
元素:
aView = np.ascontiguousarray(arr).flatten().view(mydtype)
您現在可以查找二維數組模式[[0, 3], [3, 0]]
dtype
[[0, 3], [3, 0]]
,它們也必須轉換為dtype
:
bView = np.array([[0, 3], [3, 0]]).flatten().view(mydtype)
現在,您可以檢查的occurrencies bView
在aView
:
np.in1d(aView, bView)
#array([ True, False, False, True], dtype=bool)
例如,使用np.where
可以輕松地將此掩碼轉換為索引。
以下函數用於實現此方法:
def check2din3d(b, a):
"""
Return where `b` (2D array) appears in `a` (3D array) along `axis=0`
"""
mydtype = np.dtype((np.void, a.dtype.itemsize*a.shape[1]*a.shape[2]))
aView = np.ascontiguousarray(a).flatten().view(mydtype)
bView = np.ascontiguousarray(b).flatten().view(mydtype)
return np.in1d(aView, bView)
考慮到@ayhan評論的更新時間表明,這種方法可以更快地在np.argwhere,但是不同並不重要,對於像下面這樣的大型數組,@ ayhan的方法要快得多:
arrLarge = np.concatenate([arr]*10000000)
arrLarge = np.concatenate([arrLarge]*10, axis=2)
pattern = np.ascontiguousarray([[0,3]*10, [3,0]*10])
%timeit np.argwhere(np.all(arrLarger==pattern, axis=(1,2)))
#1 loops, best of 3: 2.99 s per loop
%timeit check2din3d(pattern, arrLarger)
#1 loops, best of 3: 4.65 s per loop
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.