簡體   English   中英

Python - 二維 Numpy 數組的交集

[英]Python - Intersection of 2D Numpy Arrays

我正在拼命尋找一種有效的方法來檢查兩個 2D numpy 數組是否相交。

所以我所擁有的是兩個數組,其中包含任意數量的二維數組,例如:

A=np.array([[2,3,4],[5,6,7],[8,9,10]])
B=np.array([[5,6,7],[1,3,4]])
C=np.array([[1,2,3],[6,6,7],[10,8,9]])

如果至少有一個向量與另一個數組中的另一個向量相交,我只需要一個 True ,否則為 false。 所以它應該給出這樣的結果:

f(A,B)  -> True
f(A,C)  -> False

我對 python 有點陌生,起初我用 python 列表編寫了我的程序,它可以工作,但當然效率非常低。 該程序需要數天才能完成,所以我現在正在研究numpy.array解決方案,但這些數組確實不是那么容易處理。

這是有關我的程序和 Python 列表解決方案的一些上下文:

我正在做的是類似於 3 維中的自我避免隨機游走。 http://en.wikipedia.org/wiki/Self-avoiding_walk 但是,我沒有進行隨機游走並希望它達到理想的長度(例如,我希望鏈由 1000 個珠子組成)而不會走到死胡同,我執行以下操作:

我創建了一個具有所需長度 N 的“扁平”鏈:

X=[]
for i in range(0,N+1):
    X.append((i,0,0))

現在我折疊這條扁平鏈:

  1. 隨機選擇其中一個元素(“pivotelement”)
  2. 隨機選擇一個方向(樞軸左側或右側的所有元素)
  3. 從空間中 9 種可能的旋轉中隨機選擇一種(3 軸 * 3 種可能的旋轉 90°、180°、270°)
  4. 使用所選旋轉旋轉所選方向的所有元素
  5. 檢查所選方向的新元素是否與另一個方向相交
  6. 沒有交集 -> 接受新配置,否則 -> 保留舊鏈。

步驟 1.-6。 必須執行大量的時間(例如,對於長度為 1000,~5000 次的鏈),因此必須有效地完成這些步驟。 我的基於列表的解決方案如下:

def PivotFold(chain):
randPiv=random.randint(1,N)  #Chooses a random pivotelement, N is the Chainlength
Pivot=chain[randPiv]  #get that pivotelement
C=[]  #C is going to be a shifted copy of the chain
intersect=False
for j in range (0,N+1):   # Here i shift the hole chain to get the pivotelement to the origin, so i can use simple rotations around the origin
    C.append((chain[j][0]-Pivot[0],chain[j][1]-Pivot[1],chain[j][2]-Pivot[2]))
rotRand=random.randint(1,18)  # rotRand is used to choose a direction and a Rotation (2 possible direction * 9 rotations = 18 possibilitys)
#Rotations around Z-Axis
if rotRand==1:
    for j in range (randPiv,N+1):
        C[j]=(-C[j][1],C[j][0],C[j][2])
        if C[0:randPiv].__contains__(C[j])==True:
            intersect=True
            break
elif rotRand==2:
    for j in range (randPiv,N+1):
        C[j]=(C[j][1],-C[j][0],C[j][2])
        if C[0:randPiv].__contains__(C[j])==True:
            intersect=True
            break
...etc
if intersect==False: # return C if there was no intersection in C
    Shizz=C
else:
    Shizz=chain
return Shizz

函數 PivotFold(chain) 將在最初平坦的鏈 X 上大量使用。 它寫得非常天真,所以也許你有一些技巧可以改進這個 ^^ 我認為 numpyarrays 會很好,因為我可以有效地移動和旋轉整個鏈而無需循環遍歷所有元素......

這應該這樣做:

In [11]:

def f(arrA, arrB):
    return not set(map(tuple, arrA)).isdisjoint(map(tuple, arrB))
In [12]:

f(A, B)
Out[12]:
True
In [13]:

f(A, C)
Out[13]:
False
In [14]:

f(B, C)
Out[14]:
False

找交點? 好的, set聽起來是一個合乎邏輯的選擇。 但是numpy.arraylist不是可散列的嗎? 好的,將它們轉換為tuple 這就是想法。

一種numpy的做法涉及非常難以閱讀的廣播:

In [34]:

(A[...,np.newaxis]==B[...,np.newaxis].T).all(1)
Out[34]:
array([[False, False],
       [ True, False],
       [False, False]], dtype=bool)
In [36]:

(A[...,np.newaxis]==B[...,np.newaxis].T).all(1).any()
Out[36]:
True

一些時間結果:

In [38]:
#Dan's method
%timeit set_comp(A,B)
10000 loops, best of 3: 34.1 µs per loop
In [39]:
#Avoiding lambda will speed things up
%timeit f(A,B)
10000 loops, best of 3: 23.8 µs per loop
In [40]:
#numpy way probably will be slow, unless the size of the array is very big (my guess)
%timeit (A[...,np.newaxis]==B[...,np.newaxis].T).all(1).any()
10000 loops, best of 3: 49.8 µs per loop

此外, numpy方法將A[...,np.newaxis]==B[...,np.newaxis].T RAM,因為A[...,np.newaxis]==B[...,np.newaxis].T步驟創建了一個 3D 數組。

使用此處概述的相同想法,您可以執行以下操作:

def make_1d_view(a):
    a = np.ascontiguousarray(a)
    dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
    return a.view(dt).ravel()

def f(a, b):
    return len(np.intersect1d(make_1d_view(A), make_1d_view(b))) != 0

>>> f(A, B)
True
>>> f(A, C)
False

這不適用於浮點類型(它不會將 +0.0 和 -0.0 視為相同的值),並且np.intersect1d使用排序,因此它具有線性而非線性的性能。 您可以通過在代碼中復制np.intersect1d的源來壓縮一些性能,而不是檢查返回數組的長度, np.any在布爾索引數組上調用np.any

您還可以通過一些np.tilenp.swapaxes業務完成工作!

def intersect2d(X, Y):
        """
        Function to find intersection of two 2D arrays.
        Returns index of rows in X that are common to Y.
        """
        X = np.tile(X[:,:,None], (1, 1, Y.shape[0]) )
        Y = np.swapaxes(Y[:,:,None], 0, 2)
        Y = np.tile(Y, (X.shape[0], 1, 1))
        eq = np.all(np.equal(X, Y), axis = 1)
        eq = np.any(eq, axis = 1)
        return np.nonzero(eq)[0]

要更具體地回答這個問題,您只需要檢查返回的數組是否為空。

這應該快得多,它不像 for 循環解決方案那樣 O(n^2),但它不是完全 numpythonic。 不確定在這里如何更好地利用 numpy

def set_comp(a, b):
   sets_a = set(map(lambda x: frozenset(tuple(x)), a))
   sets_b = set(map(lambda x: frozenset(tuple(x)), b))
   return not sets_a.isdisjoint(sets_b)

如果兩個數組設置了子數組,我認為你想要真! 你可以使用這個:

def(A,B):
 for i in A:
  for j in B:
   if i==j
   return True
 return False 

這個問題可以使用numpy_indexed包有效地解決(免責聲明:我是它的作者):

import numpy_indexed as npi
len(npi.intersection(A, B)) > 0

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM