Is there any reasonable way to use np.linalg.multi_dot() with Nx2x2 arrays, the way functools.reduce(np.matmul, Nx2x2_arrays) works? Please see the example below.
import numpy as np
from functools import reduce
m1 = np.array(range(16)).reshape(4, 2, 2)
m2 = m1.copy()
m3 = m1.copy()
reduce(np.matmul, (m1, m2, m3))
The result is a 4x2x2 array:
array([[[    6,    11],
        [   22,    39]],

       [[  514,   615],
        [  738,   883]],

       [[ 2942,  3267],
        [ 3630,  4031]],

       [[ 8826,  9503],
        [10234, 11019]]])
As you can see, np.matmul treats a 4x2x2 array as a stack of four 2x2 matrices and multiplies them slice by slice. Can I do the same using np.linalg.multi_dot() instead of reduce(np.matmul), and, if so, will it lead to any performance improvement?
np.linalg.multi_dot() tries to optimize the operation by finding the order of dot products that leads to the fewest multiplications overall.
Since all your matrices are square and the same size (2x2), the order of the dot products does not matter: every parenthesization requires the same number of multiplications, so multi_dot() has nothing to optimize here.
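To make the point about ordering concrete, here is a small sketch that counts scalar multiplications for the two ways of parenthesizing a three-matrix product. The helper names cost_left_first and cost_right_first are made up for illustration and are not part of NumPy; the rectangular shapes are an arbitrary example chosen to show a case where the order does matter.

```python
def cost_left_first(a, b, c, d):
    """Scalar multiplications for (A @ B) @ C, with A: a x b, B: b x c, C: c x d."""
    return a * b * c + a * c * d


def cost_right_first(a, b, c, d):
    """Scalar multiplications for A @ (B @ C), with the same shapes as above."""
    return b * c * d + a * b * d


# Three 2x2 matrices, as in the question: both orders cost the same.
square = (2, 2, 2, 2)
print(cost_left_first(*square), cost_right_first(*square))  # 16 16

# Hypothetical rectangular shapes: here the order makes a tenfold difference.
rect = (10, 100, 5, 50)
print(cost_left_first(*rect), cost_right_first(*rect))  # 7500 75000
```

Only in the second kind of situation does the reordering done by multi_dot() have anything to gain.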
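Internally, np.linalg.multi_dot() doesn't run any C code of its own but merely calls out to np.dot(). It also expects 2-D operands, so it does not broadcast over a stack of matrices the way np.matmul does. You can therefore simply chain the products yourself: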
functools.reduce(np.matmul, (m1, m2, m3))
or simply
m1 @ m2 @ m3
You can also use np.einsum():
np.einsum('ijk,ikl,ilm->ijm',m1,m2,m3)
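As a quick sanity check, the sketch below rebuilds the arrays from the question and confirms that the three approaches agree. It also shows one way to use multi_dot() on stacked matrices, namely by applying it to each 2x2 slice in a Python loop; this assumes, as the NumPy documentation states, that multi_dot() only accepts 2-D operands (with 1-D allowed for the first and last). The variable names are illustrative only.

```python
from functools import reduce

import numpy as np

# Same inputs as in the question: a stack of four 2x2 matrices.
m1 = np.arange(16).reshape(4, 2, 2)
m2 = m1.copy()
m3 = m1.copy()

# Three batched formulations of the same chained product.
via_reduce = reduce(np.matmul, (m1, m2, m3))
via_operator = m1 @ m2 @ m3
via_einsum = np.einsum('ijk,ikl,ilm->ijm', m1, m2, m3)

# multi_dot() applied slice by slice, then restacked into a 4x2x2 array.
via_multi_dot = np.stack(
    [np.linalg.multi_dot([a, b, c]) for a, b, c in zip(m1, m2, m3)]
)

assert np.array_equal(via_reduce, via_operator)
assert np.array_equal(via_reduce, via_einsum)
assert np.array_equal(via_reduce, via_multi_dot)
print(via_reduce[0])
```

The per-slice loop gives the same result, but for equally sized square matrices it has no multiplication count to reduce and adds Python-level looping, so it is unlikely to be faster than the batched versions; timing it on your own data is the only way to be sure.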