简体   繁体   English

将 numpy.linalg.multi_dot 用于 (N, M, M) 形状的 3 维数组

[英]use numpy.linalg.multi_dot for 3-dimensional arrays of (N, M, M) shape

Is that any reasonable way to use np.linalg.multi_dot() function with Nx2x2 arrays like functools.reduce(np.matmul, Nx2x2_arrays)?这是将 np.linalg.multi_dot() 函数与 Nx2x2 数组(如 functools.reduce(np.matmul, Nx2x2_arrays))一起使用的合理方法吗? Please, see example below.请看下面的例子。

import numpy as np
from functools import reduce

m1 = np.array(range(16)).reshape(4, 2, 2)
m2 = m1.copy()
m3 = m1.copy()

reduce(np.matmul, (m1, m2, m3))

result - 4x2x2 array:结果 - 4x2x2 数组:

array([[[    6,    11],
        [   22,    39]],

       [[  514,   615],
        [  738,   883]],

       [[ 2942,  3267],
        [ 3630,  4031]],

       [[ 8826,  9503],
        [10234, 11019]]])

As you see, np.matmul treats 4x2x2 3-D arrays like 1-D arrays of 2x2 matrices.如您所见, np.matmul 将 4x2x2 3-D 数组视为 2x2 矩阵的 1-D 数组。 Can I do the same using np.linalg.multi_dot() instead of reduce(np.matmul) and, if yes, will it lead to any performance improvement?我可以使用 np.linalg.multi_dot() 而不是 reduce(np.matmul) 来做同样的事情,如果是,它会导致任何性能改进吗?

np.linalg.multi_dot() tries to optimize the operation by finding the order of dot products that leads to the fewest multiplications overall. np.linalg.multi_dot()尝试通过找到导致总体乘法最少的点积的顺序来优化操作。

As all your matrices are square, the order of dot products does not matter and you will always end up with the same number of multiplications.由于您所有的矩阵都是方阵,点积的顺序无关紧要,您将始终得到相同数量的乘法。

Internally, np.linalg.multi_dot() doesn't run any C code but merely calls out to np.dot() , so you can do the same:在内部, np.linalg.multi_dot()不运行任何 C 代码,而只是调用np.dot() ,因此您可以执行相同的操作:

functools.reduce(np.matmul, (m1, m2, m3))

or simply或者干脆

m1 @ m2 @ m3

你也可以使用np.einsum()

np.einsum('ijk,ikl,ilm->ijm',m1,m2,m3)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM