簡體   English   中英

np.where 也檢查多維數組中的子元素

[英]np.where checking also for subelements in multidimensional arrays

我有兩個具有相同第二維的多維數組。 我想確保第一個數組的任何元素(即沒有行)也是第二個數組的行。

為此,我使用numpy.where ,但它的行為也在檢查相同位置的子元素。 例如考慮這個代碼:

x = np.array([[0,1,2,3], [4,0,6,9]])
z= np.array([[0,1,2,3], [5, 11, 6,98]])
for el in x:
    print(np.where(z==el))

它打印:

(array([0, 0, 0, 0]), array([0, 1, 2, 3]))
(array([1]), array([2]))

其中第一個結果是由於第一個數組相等,第二個結果是因為z[1]x[1]都有6作為第三個元素。 有沒有辦法告訴np.where只返回嚴格相等元素的索引,即上面例子中的0

[i for i, e in enumerate(x) if (e == z).all(1).any()]

測試用例:

x = np.array([[0,1,2,3], [4,0,6,9], [4,0,6,19]])
z= np.array([[4,0,6,9], [0,1,2,3]])

[i for i, e in enumerate(x) if (e == z).all(1).any()]

輸出:

[0, 1]

哪里簡單地返回您的條件的索引 - 這里是元素明智的相等

回答

您可以使用矢量化操作找到重復項:

duplicates = (x[:, None] == z).all(-1).any(-1)

獲取值

要獲取重復值,請使用掩碼

x[duplicates]

在這個例子中:

duplicates = [True False]

x[duplicates] = [[0, 1, 2, 3]]

邏輯

  1. 擴展數組[:, None]
  2. 僅查找整行匹配all(-1)
  3. 返回至少有一個匹配any(-1)

伙計,自從np.unique添加了axis參數以來,我還沒有機會鏈接到這個答案。 歸功於@Jaime

vview = lambda a: np.ascontiguousarray(a).view(np.dtype((np.void, a.dtype.itemsize * a.shape[1])))

基本上,這需要矩陣的“行”並將它們轉換為行的原始數據流上的一維視圖數組。 這使您可以比較行,就好像它們是單個值一樣。

然后就相當簡單了:

print(np.where(vview(x) == vview(z).T))
(array([0], dtype=int64), array([0], dtype=int64))

表示x的第一行與z的第一行匹配

如果您只想知道x的行是否在z行中:

print(np.where(np.isin(vview(x), vview(z)).squeeze()))
(array([0], dtype=int64),)

與@mujjiga 在大數組上相比的檢查時間:

x = np.random.randint(10, size = (1000, 4))

z = np.random.randint(10, size = (1000, 4))

%timeit np.where(np.isin(vview(x), vview(z)).squeeze())
365 µs ± 13.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit [i for i, e in enumerate(x) if (e == z).all(1).any()]  # @mujjiga
21.3 ms ± 1.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit np.where((x[:, None] == z).all(-1).any(-1))  # @orgoro
20 ms ± 767 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

因此,循環切片的速度提高了 60 倍,這可能是由於快速短路並且只比較了 1/4 的值

好吧,對於 2D 數組,以下內容可能有用。 我認為您必須小心檢查浮點運算中的ee==0

import numpy as np
aa = np.arange(16).reshape(4,4)
# we are trying to find the row in aa which is equal to bb
bb = np.asarray([0,1,2,3])
cc = bb[None,:]
dd = aa - cc
ee = np.linalg.norm(dd,axis=1)
idx = np.where(ee==0)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM