
How to multiply arrays of matrices in NumPy

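import numpy as np

N = 10  # number of matrix pairs (example value)
a = np.ones((2, 3, N))  # N matrices of shape 2x3, stacked along the last axis
b = np.ones((3, 2, N))  # N matrices of shape 3x2, stacked along the last axis

I want to compute the matrix product b·a for each of the N pairs of matrices. I could do this with a loop:

c = np.zeros((3, 3, N))
for i in range(N):
    # multiply the i-th 3x2 matrix of b by the i-th 2x3 matrix of a
    c[:, :, i] = b[:, :, i].dot(a[:, :, i])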

but it's slow for large N. Is there a fast way to do this in one line?

Just swap the axes so that N is the first dimension. The @ operator (np.matmul) treats every axis except the last two as a batch dimension, so it then multiplies the N pairs of matrices directly. I'm also swapping c back to your desired shape of 3x3xN:

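import numpy as np

N = 10
a = np.ones((2, 3, N))
b = np.ones((3, 2, N))
# Move the N axis to the front: a becomes (N, 2, 3), b becomes (N, 3, 2)
a = a.swapaxes(1, 2).swapaxes(0, 1)
b = b.swapaxes(1, 2).swapaxes(0, 1)
# b @ a has shape (N, 3, 3); move the N axis back to the end to get (3, 3, N)
c = (b @ a).swapaxes(0, 1).swapaxes(1, 2)
print(c.shape)
# (3, 3, 10)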
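As a sanity check, here is a minimal sketch that compares the axis-swapping approach against the explicit loop from the question. It uses random data instead of np.ones, because with all-ones inputs every entry of the product equals 2 and an axis-ordering mistake would go unnoticed. It also uses np.moveaxis, which (as an assumption of readability preference, not part of the answer above) expresses "move axis 2 to the front" in one call instead of two swapaxes calls; the resulting shapes are the same.

```python
import numpy as np

# Random data: with np.ones every entry of the product would be 2,
# which hides axis-ordering mistakes.
rng = np.random.default_rng(0)
N = 10
a = rng.standard_normal((2, 3, N))  # N matrices of shape 2x3
b = rng.standard_normal((3, 2, N))  # N matrices of shape 3x2

# Reference result: the explicit loop from the question.
c_loop = np.zeros((3, 3, N))
for i in range(N):
    c_loop[:, :, i] = b[:, :, i].dot(a[:, :, i])

# Vectorized: move the N axis to the front, let @ multiply batch-wise,
# then move the N axis back to the end.
c_vec = np.moveaxis(np.moveaxis(b, 2, 0) @ np.moveaxis(a, 2, 0), 0, 2)

print(c_vec.shape)                 # (3, 3, 10)
print(np.allclose(c_loop, c_vec))  # True
```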

Another solution is to use transpose:

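import numpy as np

N = 5
a = np.ones((2, 3, N))
b = np.ones((3, 2, N))

# Move the N axis to the front, multiply batch-wise, then move it back to the end
c = (b.transpose(2, 0, 1) @ a.transpose(2, 0, 1)).transpose(1, 2, 0)

c.shape
# (3, 3, 5)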
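A different technique, not taken from either answer above, is np.einsum, which states the per-slice product directly in index notation and keeps N in the last axis without any transposes. The sketch below writes c[i, k, n] = sum over j of b[i, j, n] * a[j, k, n], which is exactly b[:, :, n] @ a[:, :, n] for every n, and checks it against the transpose-based answer on random data.

```python
import numpy as np

rng = np.random.default_rng(1)
N = 5
a = rng.standard_normal((2, 3, N))  # shape (2, 3, N)
b = rng.standard_normal((3, 2, N))  # shape (3, 2, N)

# c[i, k, n] = sum over j of b[i, j, n] * a[j, k, n],
# i.e. c[:, :, n] = b[:, :, n] @ a[:, :, n] for every n.
c_einsum = np.einsum('ijn,jkn->ikn', b, a)

# The transpose-based answer, for comparison.
c_transpose = (b.transpose(2, 0, 1) @ a.transpose(2, 0, 1)).transpose(1, 2, 0)

print(c_einsum.shape)                      # (3, 3, 5)
print(np.allclose(c_einsum, c_transpose))  # True
```

Whether einsum or @ is faster depends on the array sizes and the NumPy build, so it is worth timing both on your actual data before choosing.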
