[英]Numpy element-wise dot product without loop and memory error
我正在用numpy處理一個簡單的問題。 我有兩個矩陣列表-例如A,B
編碼為形狀分別為(n,p,q)
和(n,q,r)
3D數組。
我想計算他們的逐元素點積,即3D數組C
這樣C[i,j,l] = sum A[i,j,:] B[i,:,l]
。 從數學上來說,這非常簡單,但是我必須遵循以下規則:
1)我只能使用numpy函數( dot
, tensordot
, einsum
等):沒有循環&cie。 這是因為我希望它能在我的gpu(帶有cupy)上工作,並且循環很糟糕。 我希望所有操作都在當前設備上進行。
2)由於我的數據可能很大,通常A
和B
已經占用了幾十Mb的內存,所以我不想構建形狀大於(n,p,q),(n,q,r),(n,p,r)
(不必存儲任何中間4D數組)。
例如,我在那里找到的解決方案正在使用:
C = np.sum(np.transpose(A,(0,2,1)).reshape(n,p,q,1)*B.reshape(n,q,1,r),-3)
從數學上講是正確的,但是它隱含了中間創建(n,p,q,r)數組的過程,這個數組對於我的目的來說太大了。
我遇到類似的麻煩
C = np.einsum('ipq,iqr->ipr',A,B)
我不知道底層的操作和構造是什么,但是它總是會導致內存錯誤。
另一方面,有些天真,例如:
C = np.array([A[i].dot(B[i]) for i in range(n)])
就內存而言似乎不錯,但在我的gpu上效率不高:列表似乎是在CPU上構建的,將其重新分配給gpu的速度很慢(如果有一種很友好的方式編寫它,那將是一個不錯的解決方案!)
謝謝您的幫助 !
您需要numpy.matmul
( 此處為cupy版本 )。 matmul
是一個“廣播”矩陣乘法。
我認為人們已經知道numpy.dot
語義很numpy.dot
,並且需要廣播矩陣乘法,但是直到python得到@
運算符之前,引入該更改的動力並不大。 我看不到dot
到任何地方,但是我懷疑更好的語義和A @ B
的易用性意味着隨着人們發現新的函數和運算符, dot
不再受歡迎。
您嘗試避免的迭代方法可能還不錯。 例如,考慮以下時間:
In [51]: A = np.ones((100,10,10))
In [52]: timeit np.array([A[i].dot(A[i]) for i in range(A.shape[0])])
439 µs ± 1.35 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [53]: timeit np.einsum('ipq,iqr->ipr',A,A)
428 µs ± 170 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [54]: timeit A@A
426 µs ± 54.6 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
在這種情況下,這三個時間都差不多。
但是我將后面的維度加倍,實際上迭代方法更快:
In [55]: A = np.ones((100,20,20))
In [56]: timeit np.array([A[i].dot(A[i]) for i in range(A.shape[0])])
702 µs ± 1.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [57]: timeit np.einsum('ipq,iqr->ipr',A,A)
1.89 ms ± 1.63 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [58]: timeit A@A
1.89 ms ± 490 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
當我將20更改為30和40時,也保持相同的模式。令我感到驚訝的是, matmul
時間與einsum
如此接近。
我想我可以嘗試將它們推到內存極限。 我沒有想要測試該方面的后端。
一旦考慮了內存管理問題,針對一個大問題進行的適度迭代就不會太可怕了。 您想要避免的事情是numpy,它涉及一個簡單任務的多次迭代。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.