[英]np.where checking also for subelements in multidimensional arrays
我有兩個具有相同第二維的多維數組。 我想確保第一個數組的任何元素(即沒有行)也是第二個數組的行。
為此,我使用numpy.where
,但它的行為也在檢查相同位置的子元素。 例如考慮這個代碼:
x = np.array([[0,1,2,3], [4,0,6,9]])
z= np.array([[0,1,2,3], [5, 11, 6,98]])
for el in x:
print(np.where(z==el))
它打印:
(array([0, 0, 0, 0]), array([0, 1, 2, 3]))
(array([1]), array([2]))
其中第一個結果是由於第一個數組相等,第二個結果是因為z[1]
和x[1]
都有6
作為第三個元素。 有沒有辦法告訴np.where
只返回嚴格相等元素的索引,即上面例子中的0
?
[i for i, e in enumerate(x) if (e == z).all(1).any()]
x = np.array([[0,1,2,3], [4,0,6,9], [4,0,6,19]])
z= np.array([[4,0,6,9], [0,1,2,3]])
[i for i, e in enumerate(x) if (e == z).all(1).any()]
輸出:
[0, 1]
哪里簡單地返回您的條件的索引 - 這里是元素明智的相等
您可以使用矢量化操作找到重復項:
duplicates = (x[:, None] == z).all(-1).any(-1)
要獲取重復值,請使用掩碼
x[duplicates]
在這個例子中:
duplicates = [True False]
x[duplicates] = [[0, 1, 2, 3]]
[:, None]
all(-1)
any(-1)
伙計,自從np.unique
添加了axis
參數以來,我還沒有機會鏈接到這個答案。 歸功於@Jaime
vview = lambda a: np.ascontiguousarray(a).view(np.dtype((np.void, a.dtype.itemsize * a.shape[1])))
基本上,這需要矩陣的“行”並將它們轉換為行的原始數據流上的一維視圖數組。 這使您可以比較行,就好像它們是單個值一樣。
然后就相當簡單了:
print(np.where(vview(x) == vview(z).T))
(array([0], dtype=int64), array([0], dtype=int64))
表示x
的第一行與z
的第一行匹配
如果您只想知道x
的行是否在z
行中:
print(np.where(np.isin(vview(x), vview(z)).squeeze()))
(array([0], dtype=int64),)
與@mujjiga 在大數組上相比的檢查時間:
x = np.random.randint(10, size = (1000, 4))
z = np.random.randint(10, size = (1000, 4))
%timeit np.where(np.isin(vview(x), vview(z)).squeeze())
365 µs ± 13.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit [i for i, e in enumerate(x) if (e == z).all(1).any()] # @mujjiga
21.3 ms ± 1.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit np.where((x[:, None] == z).all(-1).any(-1)) # @orgoro
20 ms ± 767 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
因此,循環和切片的速度提高了 60 倍,這可能是由於快速短路並且只比較了 1/4 的值
好吧,對於 2D 數組,以下內容可能有用。 我認為您必須小心檢查浮點運算中的ee==0
。
import numpy as np
aa = np.arange(16).reshape(4,4)
# we are trying to find the row in aa which is equal to bb
bb = np.asarray([0,1,2,3])
cc = bb[None,:]
dd = aa - cc
ee = np.linalg.norm(dd,axis=1)
idx = np.where(ee==0)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.