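import numpy as np

N = 10000  # example stack size; any positive integer works
a = np.ones((2, 3, N))
b = np.ones((3, 2, N))
I want to compute the matrix product b @ a (not the element-wise b*a) for each of the N matrices, i.e. c[:, :, i] = b[:, :, i] @ a[:, :, i]. I could do this with a loop:
c = np.zeros((3, 3, N))
for i in range(N):
    c[:, :, i] = b[:, :, i].dot(a[:, :, i])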
but it's slow for large N. Is there a fast way to do this in one line?
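Just swap the axes so that N is the first dimension. The @ operator (np.matmul) treats the last two axes of each operand as the matrices and broadcasts over any leading axes, so it multiplies all N pairs in a single call. I'm also swapping c back to your desired 3x3xN shape: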
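import numpy as np

N = 10
a = np.ones((2, 3, N))
b = np.ones((3, 2, N))
a = a.swapaxes(1, 2).swapaxes(0, 1)  # (2, 3, N) -> (N, 2, 3)
b = b.swapaxes(1, 2).swapaxes(0, 1)  # (3, 2, N) -> (N, 3, 2)
# (N, 3, 2) @ (N, 2, 3) -> (N, 3, 3), then move N back to the last axis
c = (b @ a).swapaxes(0, 1).swapaxes(1, 2)
print(c.shape)
# (3, 3, 10)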
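With np.ones every entry of every product is 2, so the example above cannot reveal a mixed-up axis order. A quick sanity check, sketched here with random data (the seed and variable names are arbitrary choices), compares the swapaxes version against the loop from the question:

```python
import numpy as np

rng = np.random.default_rng(0)
N = 10
# Random data, so a wrong axis order would give visibly different numbers.
a = rng.standard_normal((2, 3, N))
b = rng.standard_normal((3, 2, N))

# Reference result: the explicit loop from the question.
c_loop = np.zeros((3, 3, N))
for i in range(N):
    c_loop[:, :, i] = b[:, :, i] @ a[:, :, i]

# Vectorized result: move N to axis 0, multiply the stacks, move N back.
a_s = a.swapaxes(1, 2).swapaxes(0, 1)               # (N, 2, 3)
b_s = b.swapaxes(1, 2).swapaxes(0, 1)               # (N, 3, 2)
c_vec = (b_s @ a_s).swapaxes(0, 1).swapaxes(1, 2)   # (3, 3, N)

print(np.allclose(c_loop, c_vec))  # True
```

np.allclose is used instead of == because the loop and the batched call are not guaranteed to round identically.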
Another solution is to use transpose:
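import numpy as np
N = 5
a = np.ones((2, 3, N))
b = np.ones((3, 2, N))
# transpose(2, 0, 1) moves N to the front; transpose(1, 2, 0) moves it back to the end
c = (b.transpose(2, 0, 1) @ a.transpose(2, 0, 1)).transpose(1, 2, 0)
print(c.shape)
# (3, 3, 5)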
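Neither answer uses np.einsum, but it is a third option worth knowing: the subscript string 'ijn,jkn->ikn' sums over the shared index j separately for every n, so N can stay on the last axis and no transposes are needed. The sketch below (random data and variable names are illustrative assumptions) also shows np.moveaxis as a more readable way to move N to the front and back, and checks that all three agree. Which one is fastest depends on the shapes and the NumPy build, so time it on your own data rather than assuming.

```python
import numpy as np

N = 5
rng = np.random.default_rng(1)
a = rng.standard_normal((2, 3, N))
b = rng.standard_normal((3, 2, N))

# einsum: c[i, k, n] = sum_j b[i, j, n] * a[j, k, n], i.e. b[:, :, n] @ a[:, :, n]
c_einsum = np.einsum('ijn,jkn->ikn', b, a)

# The transpose approach from the answer above, for comparison.
c_transpose = (b.transpose(2, 0, 1) @ a.transpose(2, 0, 1)).transpose(1, 2, 0)

# moveaxis: move the last axis (N) to the front, multiply, then move it back.
c_moveaxis = np.moveaxis(np.moveaxis(b, -1, 0) @ np.moveaxis(a, -1, 0), 0, -1)

print(c_einsum.shape)                      # (3, 3, 5)
print(np.allclose(c_einsum, c_transpose))  # True
print(np.allclose(c_einsum, c_moveaxis))   # True
```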