[英]Remove only rows which contain duplicates within that row of 3D numpy array
我有一個像這樣的3D numpy數組:
>>> a
array([[[0, 1, 2],
[0, 1, 2],
[6, 7, 8]],
[[6, 7, 8],
[0, 1, 2],
[6, 7, 8]],
[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]])
我想只刪除那些包含重復項的行。 例如,輸出應如下所示:
>>> remove_row_duplicates(a)
array([[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]])
這是我正在使用的功能:
delindices = np.empty(0, dtype=int)
for i in range(len(a)):
_, indices = np.unique(np.around(a[i], decimals=10), axis=0, return_index=True)
if len(indices) < len(a[i]):
delindices = np.append(delindices, i)
a = np.delete(a, delindices, 0)
這很好用,但問題是現在我的數組形狀就像(1000000,7,3)。 for循環在python中非常慢,這需要花費很多時間。 我的原始數組也包含浮點數。 誰有更好的解決方案或誰可以幫助我矢量化這個功能?
沿着每個2D block
的行對其進行排序,即沿axis=1
,然后沿着連續的行查找匹配的行,最后查找沿同一axis=1
any
匹配axis=1
-
b = np.sort(a,axis=1)
out = a[~((b[:,1:] == b[:,:-1]).all(-1)).any(1)]
示例運行說明
輸入數組:
In [51]: a
Out[51]:
array([[[0, 1, 2],
[0, 1, 2],
[6, 7, 8]],
[[6, 7, 8],
[0, 1, 2],
[6, 7, 8]],
[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]])
代碼步驟:
# Sort along axis=1, i.e rows in each 2D block
In [52]: b = np.sort(a,axis=1)
In [53]: b
Out[53]:
array([[[0, 1, 2],
[0, 1, 2],
[6, 7, 8]],
[[0, 1, 2],
[6, 7, 8],
[6, 7, 8]],
[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]])
In [54]: (b[:,1:] == b[:,:-1]).all(-1) # Look for successive matching rows
Out[54]:
array([[ True, False],
[False, True],
[False, False]])
# Look for matches along each row, which indicates presence
# of duplicate rows within each 2D block in original 2D array
In [55]: ((b[:,1:] == b[:,:-1]).all(-1)).any(1)
Out[55]: array([ True, True, False])
# Invert those as we need to remove those cases
# Finally index with boolean indexing and get the output
In [57]: a[~((b[:,1:] == b[:,:-1]).all(-1)).any(1)]
Out[57]:
array([[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]])
您可以使用廣播輕松完成此操作,但由於您處理的不僅僅是2D數組,因此它不會像您預期的那樣優化,甚至在某些情況下也會非常慢。 相反,您可以使用受Jaime答案啟發的以下方法:
In [28]: u = np.unique(arr.view(np.dtype((np.void, arr.dtype.itemsize*arr.shape[1])))).view(arr.dtype).reshape(-1, arr.shape[1])
In [29]: inds = np.where((arr == u).all(2).sum(0) == u.shape[1])
In [30]: arr[inds]
Out[30]:
array([[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]])
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.