繁体   English   中英

将 numpy.linalg.multi_dot 用于 (N, M, M) 形状的 3 维数组

[英]use numpy.linalg.multi_dot for 3-dimensional arrays of (N, M, M) shape

这是将 np.linalg.multi_dot() 函数与 Nx2x2 数组(如 functools.reduce(np.matmul, Nx2x2_arrays))一起使用的合理方法吗? 请看下面的例子。

import numpy as np
from functools import reduce

m1 = np.array(range(16)).reshape(4, 2, 2)
m2 = m1.copy()
m3 = m1.copy()

reduce(np.matmul, (m1, m2, m3))

结果 - 4x2x2 数组:

array([[[    6,    11],
        [   22,    39]],

       [[  514,   615],
        [  738,   883]],

       [[ 2942,  3267],
        [ 3630,  4031]],

       [[ 8826,  9503],
        [10234, 11019]]])

如您所见, np.matmul 将 4x2x2 3-D 数组视为 2x2 矩阵的 1-D 数组。 我可以使用 np.linalg.multi_dot() 而不是 reduce(np.matmul) 来做同样的事情,如果是,它会导致任何性能改进吗?

np.linalg.multi_dot()尝试通过找到导致总体乘法最少的点积的顺序来优化操作。

由于您所有的矩阵都是方阵,点积的顺序无关紧要,您将始终得到相同数量的乘法。

在内部, np.linalg.multi_dot()不运行任何 C 代码,而只是调用np.dot() ,因此您可以执行相同的操作:

functools.reduce(np.matmul, (m1, m2, m3))

或者干脆

m1 @ m2 @ m3

你也可以使用np.einsum()

np.einsum('ijk,ikl,ilm->ijm',m1,m2,m3)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM