[英]PyTorch: What is numpy.linalg.multi_dot() equivalent in PyTorch
[英]use numpy.linalg.multi_dot for 3-dimensional arrays of (N, M, M) shape
这是将 np.linalg.multi_dot() 函数与 Nx2x2 数组(如 functools.reduce(np.matmul, Nx2x2_arrays))一起使用的合理方法吗? 请看下面的例子。
import numpy as np
from functools import reduce
m1 = np.array(range(16)).reshape(4, 2, 2)
m2 = m1.copy()
m3 = m1.copy()
reduce(np.matmul, (m1, m2, m3))
结果 - 4x2x2 数组:
array([[[ 6, 11],
[ 22, 39]],
[[ 514, 615],
[ 738, 883]],
[[ 2942, 3267],
[ 3630, 4031]],
[[ 8826, 9503],
[10234, 11019]]])
如您所见, np.matmul 将 4x2x2 3-D 数组视为 2x2 矩阵的 1-D 数组。 我可以使用 np.linalg.multi_dot() 而不是 reduce(np.matmul) 来做同样的事情,如果是,它会导致任何性能改进吗?
np.linalg.multi_dot()
尝试通过找到导致总体乘法最少的点积的顺序来优化操作。
由于您所有的矩阵都是方阵,点积的顺序无关紧要,您将始终得到相同数量的乘法。
在内部, np.linalg.multi_dot()
不运行任何 C 代码,而只是调用np.dot()
,因此您可以执行相同的操作:
functools.reduce(np.matmul, (m1, m2, m3))
或者干脆
m1 @ m2 @ m3
你也可以使用np.einsum()
:
np.einsum('ijk,ikl,ilm->ijm',m1,m2,m3)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.