[英]Efficiently find row intersections of two 2-D numpy arrays
我試圖找出一種有效的方法來找到兩個np.arrays
行交叉點。
兩個數組具有相同的形狀,並且每行中的重復值不會發生。
例如:
import numpy as np
a = np.array([[2,5,6],
[8,2,3],
[4,1,5],
[1,7,9]])
b = np.array([[2,3,4], # one element(2) in common with a[0] -> 1
[7,4,3], # one element(3) in common with a[1] -> 1
[5,4,1], # three elements(5,4,1) in common with a[2] -> 3
[7,6,9]]) # two element(9,7) in common with a[3] -> 2
我想要的輸出是: np.array([1,1,3,2])
使用循環很容易做到這一點:
def get_intersect1ds(a, b):
result = np.empty(a.shape[0], dtype=np.int)
for i in xrange(a.shape[0]):
result[i] = (len(np.intersect1d(a[i], b[i])))
return result
結果:
>>> get_intersect1ds(a, b)
array([1, 1, 3, 2])
但是有更有效的方法嗎?
如果你在一行中沒有重復項,你可以嘗試復制np.intersect1d
在np.intersect1d
的內容(請參閱此處的源代碼):
>>> c = np.hstack((a, b))
>>> c
array([[2, 5, 6, 2, 3, 4],
[8, 2, 3, 7, 4, 3],
[4, 1, 5, 5, 4, 1],
[1, 7, 9, 7, 6, 9]])
>>> c.sort(axis=1)
>>> c
array([[2, 2, 3, 4, 5, 6],
[2, 3, 3, 4, 7, 8],
[1, 1, 4, 4, 5, 5],
[1, 6, 7, 7, 9, 9]])
>>> c[:, 1:] == c[:, :-1]
array([[ True, False, False, False, False],
[False, True, False, False, False],
[ True, False, True, False, True],
[False, False, True, False, True]], dtype=bool)
>>> np.sum(c[:, 1:] == c[:, :-1], axis=1)
array([1, 1, 3, 2])
這個答案可能不可行,因為如果輸入具有形狀(N,M),它會生成一個大小為(N,M,M)的中間數組,但看看你可以用廣播做什么總是很有趣:
In [43]: a
Out[43]:
array([[2, 5, 6],
[8, 2, 3],
[4, 1, 5],
[1, 7, 9]])
In [44]: b
Out[44]:
array([[2, 3, 4],
[7, 4, 3],
[5, 4, 1],
[7, 6, 9]])
In [45]: (np.expand_dims(a, -1) == np.expand_dims(b, 1)).sum(axis=-1).sum(axis=-1)
Out[45]: array([1, 1, 3, 2])
對於大型數組,通過批量操作可以使該方法更加內存友好。
我想不出一個干凈的純粹解決方案,但以下建議應該加快速度,可能是顯着的:
@autojit
裝飾你的get_intersect1ds
函數一樣簡單 intersect1d
時,傳遞assume_unique = True
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.