
Using numpy.linalg.multi_dot with 3-dimensional arrays of shape (N, M, M)

Is there any reasonable way to use np.linalg.multi_dot() with Nx2x2 arrays, the way functools.reduce(np.matmul, Nx2x2_arrays) works? Please see the example below.

import numpy as np
from functools import reduce

m1 = np.array(range(16)).reshape(4, 2, 2)
m2 = m1.copy()
m3 = m1.copy()

reduce(np.matmul, (m1, m2, m3))

The result is a 4x2x2 array:

array([[[    6,    11],
        [   22,    39]],

       [[  514,   615],
        [  738,   883]],

       [[ 2942,  3267],
        [ 3630,  4031]],

       [[ 8826,  9503],
        [10234, 11019]]])

As you can see, np.matmul treats a 4x2x2 array as a stack of four 2x2 matrices and multiplies them slice by slice. Can I do the same with np.linalg.multi_dot() instead of reduce(np.matmul, ...) and, if so, will it improve performance?

np.linalg.multi_dot() tries to optimize the operation by finding the order of dot products that leads to the fewest multiplications overall.

Since all your matrices are square and the same size, the order of the dot products does not matter: every ordering needs the same number of multiplications, so there is nothing for multi_dot to optimize.
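To make the point about ordering concrete, here is a minimal sketch that counts scalar multiplications for a chain of three matrices A (p x q), B (q x r) and C (r x s). The helper names and the rectangular shapes are made up for illustration; they are not part of NumPy or of the question.

```python
def cost_left_first(p, q, r, s):
    """Scalar multiplications for (A @ B) @ C."""
    return p * q * r + p * r * s


def cost_right_first(p, q, r, s):
    """Scalar multiplications for A @ (B @ C)."""
    return q * r * s + p * q * s


# Hypothetical rectangular chain: 10x100, 100x5, 5x50. Here the order matters.
rect_left = cost_left_first(10, 100, 5, 50)
rect_right = cost_right_first(10, 100, 5, 50)

# The 2x2 matrices from the question: both orders cost the same.
square_left = cost_left_first(2, 2, 2, 2)
square_right = cost_right_first(2, 2, 2, 2)

print(rect_left, rect_right)      # 7500 75000
print(square_left, square_right)  # 16 16
```

For the rectangular chain the two orders differ by a factor of ten, which is the case multi_dot is meant for. For equal-sized square matrices the counts are identical.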

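Internally, np.linalg.multi_dot() doesn't run any C code of its own; it merely calls np.dot() in the order it has chosen. It also expects 2-D arrays (only the first and last may be 1-D), so it does not batch over a leading axis the way np.matmul does. You lose nothing by sticking with: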

reduce(np.matmul, (m1, m2, m3))

or simply

m1 @ m2 @ m3
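If you still want to go through multi_dot, one option is to call it once per 2x2 slice and stack the results. This is only a sketch to show that it gives the same answer; the Python-level loop is my own workaround, not something multi_dot does for you, and I would not expect it to be faster than the batched matmul.

```python
import numpy as np

m1 = np.arange(16).reshape(4, 2, 2)
m2 = m1.copy()
m3 = m1.copy()

# Batched product: matmul broadcasts over the leading axis.
batched = m1 @ m2 @ m3

# Workaround (assumption: one multi_dot call per 2-D slice, then restack).
looped = np.stack(
    [np.linalg.multi_dot([a, b, c]) for a, b, c in zip(m1, m2, m3)]
)

print(np.array_equal(batched, looped))
```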

You can also use np.einsum():

np.einsum('ijk,ikl,ilm->ijm',m1,m2,m3)
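In that subscript string, i is the batch index shared by all three operands and kept in the output, while k and l are the inner dimensions that get summed over. A quick check that it agrees with the matmul chain, using the arrays from the question:

```python
import numpy as np

m1 = np.arange(16).reshape(4, 2, 2)
m2 = m1.copy()
m3 = m1.copy()

# i: batch axis (kept), j/m: output rows/columns, k/l: summed inner axes.
via_einsum = np.einsum('ijk,ikl,ilm->ijm', m1, m2, m3)
via_matmul = m1 @ m2 @ m3

print(np.array_equal(via_einsum, via_matmul))
```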
