簡體   English   中英

如何在3d numpy數組中找到2d數組的行

[英]How to find row of 2d array in 3d numpy array

我試圖找到一個二維數組出現在3d numpy ndarray中的行。 這是我的意思的一個例子。 給:

arr = [[[0, 3], [3, 0]],
       [[0, 0], [0, 0]],
       [[3, 3], [3, 3]],
       [[0, 3], [3, 0]]]

我想找到所有出現的:

[[0, 3], [3, 0]]

我想要的結果是:

[0, 3]

我試圖使用argwhere但不幸的是讓我無處可去。 有任何想法嗎?

嘗試

np.argwhere(np.all(arr==[[0,3], [3,0]], axis=(1,2)))

這個怎么運作:

arr == [[0,3], [3,0]]返回

array([[[ True,  True],
        [ True,  True]],

       [[ True, False],
        [False,  True]],

       [[False,  True],
        [ True, False]],

       [[ True,  True],
        [ True,  True]]], dtype=bool)

這是一個三維數組,其中最內軸是2.此軸的值為:

[True, True]
[True, True]
[True, False]
[False, True]
[False, True]
[True, False]
[True, True]
[True, True]

現在使用np.all(arr==[[0,3], [3,0]], axis=2)您將檢查行上的兩個元素是否為True ,其形狀是否將減少為(4,2)來自(4,2,2)。 像這樣:

array([[ True,  True],
       [False, False],
       [False, False],
       [ True,  True]], dtype=bool)

您需要再減少一步,因為您希望它們兩者相同( [0, 3][3, 0] 。您可以通過減少結果(現在最內軸為1)來實現:

np.all(np.all(test, axis = 2), axis=1)

或者您也可以通過為軸參數提供元組來逐步執行相同的操作(首先是最里面,然后再高一步)。 結果將是:

array([ True, False, False,  True], dtype=bool)

numpy_indexed包中的'contains'函數(免責聲明:我是它的作者)可用於進行此類查詢。 它實現了類似於Saullo提供的解決方案。

import numpy_indexed as npi
test = [[[0, 3], [3, 0]]]
# check which elements of arr are present in test (checked along axis=0 by default)
flags = npi.contains(test, arr)
# if you want the indexes:
idx = np.flatnonzero(flags)

在定義一個新數據類型后,可以使用np.in1d ,該數據類型將包含arr中每行的內存大小。 要定義此類數據類型:

mydtype = np.dtype((np.void, arr.dtype.itemsize*arr.shape[1]*arr.shape[2]))

那么你必須將你的arr轉換為一維數組,其中每一行都有arr.shape[1]*arr.shape[2]元素:

aView = np.ascontiguousarray(arr).flatten().view(mydtype)

您現在可以查找二維數組模式[[0, 3], [3, 0]] dtype [[0, 3], [3, 0]] ,它們也必須轉換為dtype

bView = np.array([[0, 3], [3, 0]]).flatten().view(mydtype)

現在,您可以檢查的occurrencies bViewaView

np.in1d(aView, bView)
#array([ True, False, False,  True], dtype=bool)

例如,使用np.where可以輕松地將此掩碼轉換為索引。

計時(更新)

以下函數用於實現此方法:

def check2din3d(b, a):
        """
        Return where `b` (2D array) appears in `a` (3D array) along `axis=0`
        """
        mydtype = np.dtype((np.void, a.dtype.itemsize*a.shape[1]*a.shape[2]))
        aView = np.ascontiguousarray(a).flatten().view(mydtype)
        bView = np.ascontiguousarray(b).flatten().view(mydtype)
        return np.in1d(aView, bView)

考慮到@ayhan評論的更新時間表明,這種方法可以更快地在np.argwhere,但是不同並不重要,對於像下面這樣的大型數組,@ ayhan的方法要快得多:

arrLarge = np.concatenate([arr]*10000000)
arrLarge = np.concatenate([arrLarge]*10, axis=2)

pattern = np.ascontiguousarray([[0,3]*10, [3,0]*10])

%timeit np.argwhere(np.all(arrLarger==pattern, axis=(1,2)))
#1 loops, best of 3: 2.99 s per loop

%timeit check2din3d(pattern, arrLarger)
#1 loops, best of 3: 4.65 s per loop

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM