[英]Combining element-wise and matrix multiplication with multi-dimensional arrays in NumPy
我有兩個多維 NumPy 數組, A
和B
, A.shape = (K, d, N)
和B.shape = (K, N, d)
。 我想在軸 0 ( K
) 上執行元素操作,該操作是在軸 1 和 2 ( d, N
和N, d
) 上的矩陣乘法。 所以結果應該是一個多維數組C
其中C.shape = (K, d, d)
,所以C[k] = np.dot(A[k], B[k])
。 一個簡單的實現看起來像這樣:
C = np.vstack([np.dot(A[k], B[k])[np.newaxis, :, :] for k in xrange(K)])
但是這個實現很慢。 稍微快一點的方法如下所示:
C = np.dot(A, B)[:, :, 0, :]
它在多維數組上使用np.dot
的默認行為,給我一個形狀為(K, d, K, d)
的數組。 但是,這種方法計算所需的答案K
次(沿軸 2 的每個條目都相同)。 漸近地它會比第一種方法慢,但開銷要少得多。 我也知道以下方法:
from numpy.core.umath_tests import matrix_multiply
C = matrix_multiply(A, B)
但我不保證此功能將可用。 因此,我的問題是,NumPy 是否提供了有效執行此操作的標准方法? 一般適用於多維數組的答案將是完美的,但僅針對這種情況的答案也會很棒。
編輯:正如@Juh_ 所指出的,第二種方法是不正確的。 正確的版本是:
C = np.dot(A, B).diagonal(axis1=0, axis2=2).transpose(2, 0, 1)
但是增加的開銷使它比第一種方法慢,即使對於小矩陣也是如此。 最后一種方法是在我所有的時序測試中,無論是小矩陣還是大矩陣,都遙遙領先。 如果沒有更好的解決方案出現,我現在強烈考慮使用它,即使這意味着將numpy.core.umath_tests
庫(用 C 編寫)復制到我的項目中。
您的問題的可能解決方案是:
C = np.sum(A[:,:,:,np.newaxis]*B[:,np.newaxis,:,:],axis=2)
然而:
順便說一句,請注意:
C = np.dot(A, B)[:, :, 0, :]
沒有給出正確的結果。 它讓我被騙了,因為我首先通過將結果與 np.dot 命令給出的結果進行比較來檢查我的方法。
我的項目中有同樣的問題。 我能想到的最好的方法是,我認為它比使用vstack
快一點(可能快 10%):
K, d, N = A.shape
C = np.empty((K, d, d))
for k in xrange(K):
C[k] = np.dot(A[k], B[k])
我很想看到更好的解決方案,我不太明白人們會如何使用tensordot
來做到這一點。
一個非常靈活、緊湊且快速的解決方案:
C = np.einsum('Kab,Kbc->Kac', A, B, optimize=True)
確認:
import numpy as np
K = 10
d = 5
N = 3
A = np.random.rand(K,d,N)
B = np.random.rand(K,N,d)
C_old = np.dot(A, B).diagonal(axis1=0, axis2=2).transpose(2, 0, 1)
C_new = np.einsum('Kab,Kbc->Kac', A, B)
print(np.max(C_old-C_new)) # should be 0 or a very small number
對於大型多維數組,可選參數optimize=True
可以為您節省大量時間。 您可以在此處了解einsum :
https://ajcr.net/Basic-guide-to-einsum/
https://rockt.github.io/2018/04/30/einsum
https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
引用:
愛因斯坦求和約定可用於計算許多多維線性代數數組運算。 einsum提供了一種簡潔的方式來表示這些。 這些操作的非詳盡列表是:
數組的跟蹤, numpy.trace 。
返回對角線numpy.diag 。
數組軸求和, numpy.sum 。
換位和排列, numpy.transpose 。
矩陣乘法和點積, numpy.matmul numpy.dot 。
矢量內積和外積, numpy.inner numpy.outer 。
廣播,元素和標量乘法, numpy.multiply 。
張量收縮, numpy.tensordot 。
鏈式數組操作,以高效的計算順序, numpy.einsum_path 。
你可以做
np.matmul(A, B)
查看https://numpy.org/doc/stable/reference/generated/numpy.matmul.html 。
對於足夠大的K
應該比 einsum 快。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.