簡體   English   中英

將元素乘法和矩陣乘法與 NumPy 中的多維數組相結合

[英]Combining element-wise and matrix multiplication with multi-dimensional arrays in NumPy

我有兩個多維 NumPy 數組, ABA.shape = (K, d, N)B.shape = (K, N, d) 我想在軸 0 ( K ) 上執行元素操作,該操作是在軸 1 和 2 ( d, NN, d ) 上的矩陣乘法。 所以結果應該是一個多維數組C其中C.shape = (K, d, d) ,所以C[k] = np.dot(A[k], B[k]) 一個簡單的實現看起來像這樣:

C = np.vstack([np.dot(A[k], B[k])[np.newaxis, :, :] for k in xrange(K)])

但是這個實現很 稍微快一點的方法如下所示:

C = np.dot(A, B)[:, :, 0, :]

它在多維數組上使用np.dot的默認行為,給我一個形狀為(K, d, K, d)的數組。 但是,這種方法計算所需的答案K次(沿軸 2 的每個條目都相同)。 漸近地它會比第一種方法慢,但開銷要少得多。 我也知道以下方法:

from numpy.core.umath_tests import matrix_multiply
C = matrix_multiply(A, B)

但我不保證此功能將可用。 因此,我的問題是,NumPy 是否提供了有效執行此操作的標准方法? 一般適用於多維數組的答案將是完美的,但僅針對這種情況的答案也會很棒。

編輯:正如@Juh_ 所指出的,第二種方法是不正確的。 正確的版本是:

C = np.dot(A, B).diagonal(axis1=0, axis2=2).transpose(2, 0, 1)

但是增加的開銷使它比第一種方法,即使對於小矩陣也是如此。 最后一種方法是在我所有的時序測試中,無論是小矩陣還是大矩陣,都遙遙領先。 如果沒有更好的解決方案出現,我現在強烈考慮使用它,即使這意味着將numpy.core.umath_tests庫(用 C 編寫)復制到我的項目中。

您的問題的可能解決方案是:

C = np.sum(A[:,:,:,np.newaxis]*B[:,np.newaxis,:,:],axis=2)

然而:

  1. 只有當 K 遠大於 d 和 N 時,它才比 vstack 方法更快
  2. 它們可能是一些內存問題:在上面的解決方案中,分配了一個 KxdxNxd 數組(即所有可能的乘積對,在求和之前)。 實際上,由於內存不足,我無法使用大 K、d 和 N 進行測試。

順便說一句,請注意:

C = np.dot(A, B)[:, :, 0, :]

沒有給出正確的結果。 它讓我被騙了,因為我首先通過將結果與 np.dot 命令給出的結果進行比較來檢查我的方法。

我的項目中有同樣的問題。 我能想到的最好的方法是,我認為它比使用vstack快一點(可能快 10%):

K, d, N = A.shape
C = np.empty((K, d, d))
for k in xrange(K):
    C[k] = np.dot(A[k], B[k])

我很想看到更好的解決方案,我不太明白人們會如何使用tensordot來做到這一點。

一個非常靈活、緊湊且快速的解決方案:

C = np.einsum('Kab,Kbc->Kac', A, B, optimize=True)

確認:

import numpy as np
K = 10
d = 5
N = 3
A = np.random.rand(K,d,N)
B = np.random.rand(K,N,d)
C_old = np.dot(A, B).diagonal(axis1=0, axis2=2).transpose(2, 0, 1)
C_new = np.einsum('Kab,Kbc->Kac', A, B)
print(np.max(C_old-C_new))  # should be 0 or a very small number

對於大型多維數組,可選參數optimize=True可以為您節省大量時間。 您可以在此處了解einsum

https://ajcr.net/Basic-guide-to-einsum/

https://rockt.github.io/2018/04/30/einsum

https://numpy.org/doc/stable/reference/generated/numpy.einsum.html

引用:

愛因斯坦求和約定可用於計算許多多維線性代數數組運算。 einsum提供了一種簡潔的方式來表示這些。 這些操作的非詳盡列表是:

  • 數組的跟蹤, numpy.trace

  • 返回對角線numpy.diag

  • 數組軸求和, numpy.sum

  • 換位和排列, numpy.transpose

  • 矩陣乘法和點積, numpy.matmul numpy.dot

  • 矢量內積和外積, numpy.inner numpy.outer

  • 廣播,元素和標量乘法, numpy.multiply

  • 張量收縮, numpy.tensordot

  • 鏈式數組操作,以高效的計算順序, numpy.einsum_path

你可以做

np.matmul(A, B)

查看https://numpy.org/doc/stable/reference/generated/numpy.matmul.html

對於足夠大的K應該比 einsum 快。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM