簡體   English   中英

將 numpy.linalg.multi_dot 用於 (N, M, M) 形狀的 3 維數組

[英]use numpy.linalg.multi_dot for 3-dimensional arrays of (N, M, M) shape

這是將 np.linalg.multi_dot() 函數與 Nx2x2 數組(如 functools.reduce(np.matmul, Nx2x2_arrays))一起使用的合理方法嗎? 請看下面的例子。

import numpy as np
from functools import reduce

m1 = np.array(range(16)).reshape(4, 2, 2)
m2 = m1.copy()
m3 = m1.copy()

reduce(np.matmul, (m1, m2, m3))

結果 - 4x2x2 數組:

array([[[    6,    11],
        [   22,    39]],

       [[  514,   615],
        [  738,   883]],

       [[ 2942,  3267],
        [ 3630,  4031]],

       [[ 8826,  9503],
        [10234, 11019]]])

如您所見, np.matmul 將 4x2x2 3-D 數組視為 2x2 矩陣的 1-D 數組。 我可以使用 np.linalg.multi_dot() 而不是 reduce(np.matmul) 來做同樣的事情,如果是,它會導致任何性能改進嗎?

np.linalg.multi_dot()嘗試通過找到導致總體乘法最少的點積的順序來優化操作。

由於您所有的矩陣都是方陣,點積的順序無關緊要,您將始終得到相同數量的乘法。

在內部, np.linalg.multi_dot()不運行任何 C 代碼,而只是調用np.dot() ,因此您可以執行相同的操作:

functools.reduce(np.matmul, (m1, m2, m3))

或者干脆

m1 @ m2 @ m3

你也可以使用np.einsum()

np.einsum('ijk,ikl,ilm->ijm',m1,m2,m3)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM