
Matrix multiplication of (n,n,M) and (n,n) matrix in numpy Python

I am using Python.

Let A.shape=(n,n,M) and B.shape=(n,n). I want to do the following:

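import numpy as np

AB = np.zeros_like(A)  # output with the same shape (and dtype) as A
for m in range(M):
    AB[:,:,m] = A[:,:,m] @ B  # multiply each (n,n) slice of A by B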

However, this loop does not seem like the most efficient way to do this. Is there a faster approach?

One option is to use np.einsum:

np.einsum('ijk,jl->ilk', A, B)
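In index notation, the subscripts 'ijk,jl->ilk' sum over the shared index j and carry k through unchanged, which is exactly the per-slice product in the loop:

```latex
(AB)_{ilk} = \sum_{j=1}^{n} A_{ijk}\,B_{jl}, \qquad k = 1, \dots, M
```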

Or transpose A so that the M axis comes first, multiply with @, and transpose the result back:

(A.transpose(2,0,1) @ B).transpose(1,2,0)

Example:

>>> import numpy as np
>>> A = np.arange(12).reshape(2,2,3)
>>> B = np.arange(4).reshape(2,2)
>>> AB = np.zeros_like(A)

>>> M = 3
>>> for m in range(M):
...     AB[:,:,m]=A[:,:,m] @ B
...
>>> AB
array([[[ 6,  8, 10],
        [ 9, 13, 17]],

       [[18, 20, 22],
        [33, 37, 41]]])

# einsum
>>> np.einsum('ijk,jl->ilk', A, B)
array([[[ 6,  8, 10],
        [ 9, 13, 17]],

       [[18, 20, 22],
        [33, 37, 41]]])

# transpose
>>> (A.transpose(2,0,1) @ B).transpose(1,2,0)
array([[[ 6,  8, 10],
        [ 9, 13, 17]],

       [[18, 20, 22],
        [33, 37, 41]]])
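The transpose route works because @ (np.matmul) treats every axis except the last two as a batch axis and broadcasts the 2-D B across it. A minimal shape trace, using small illustrative sizes n=2, M=3 chosen here so that n != M:

```python
import numpy as np

n, M = 2, 3  # illustrative sizes, chosen so n != M
A = np.arange(n * n * M).reshape(n, n, M)
B = np.arange(n * n).reshape(n, n)

A_batched = A.transpose(2, 0, 1)  # (M, n, n): A_batched[m] is A[:, :, m]
prod = A_batched @ B              # matmul broadcasts B over the leading M axis
AB = prod.transpose(1, 2, 0)      # back to (n, n, M)

print(A_batched.shape, prod.shape, AB.shape)  # (3, 2, 2) (3, 2, 2) (2, 2, 3)
```

Both transposes only create views, so the only real work is the single batched matmul call.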

We can use np.tensordot -

np.tensordot(A,B,axes=(1,0)).swapaxes(1,2)
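np.tensordot returns the uncontracted axes of A (here i and k) followed by the uncontracted axes of B (here l), so the raw result is laid out as (n, M, n), and swapaxes(1, 2) moves l in front of k. A quick shape check, with illustrative sizes n=4, M=3:

```python
import numpy as np

n, M = 4, 3  # illustrative sizes; n != M makes axis mix-ups visible
rng = np.random.default_rng(0)
A = rng.random((n, n, M))
B = rng.random((n, n))

raw = np.tensordot(A, B, axes=(1, 0))  # contract A's axis 1 (j) with B's axis 0 (j)
print(raw.shape)                       # (4, 3, 4): axes (i, k, l)

AB = raw.swapaxes(1, 2)                # axes (i, l, k)
print(AB.shape)                        # (4, 4, 3)
```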

A related post explains how tensordot works.

Under the hood, it permutes the axes to align them, reshapes the operands to 2-D and then uses BLAS-based matrix multiplication via np.dot. Doing that work by hand would look something along these lines -

A.swapaxes(1,2).reshape(-1,n).dot(B).reshape(n,-1,n).swapaxes(1,2)
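The one-liner above can be unrolled into named steps to show each intermediate shape. The function name matmul_last_axis_batched is hypothetical, introduced only for this sketch, and it assumes A has shape (n, n, M) and B has shape (n, n):

```python
import numpy as np

def matmul_last_axis_batched(A, B):
    """Hypothetical helper: returns AB with AB[:, :, m] = A[:, :, m] @ B for every m."""
    n, _, M = A.shape
    A_ikj = A.swapaxes(1, 2)        # (n, M, n): axes (i, k, j), contraction axis j now last
    A_2d = A_ikj.reshape(-1, n)     # (n*M, n): rows are (i, k) pairs; reshape generally copies here
    prod = A_2d.dot(B)              # (n*M, n): one matrix product covering all M slices
    AB_ikl = prod.reshape(n, M, n)  # (n, M, n): axes (i, k, l)
    return AB_ikl.swapaxes(1, 2)    # (n, n, M): axes (i, l, k)

rng = np.random.default_rng(0)
A = rng.random((4, 4, 3))
B = rng.random((4, 4))
AB = matmul_last_axis_batched(A, B)
print(AB.shape)  # (4, 4, 3)
```

Folding the M slices into the row dimension is what turns M small products into one large one, which is where the BLAS call pays off.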

Starting instead with B as the left operand (via B.T), it would be something like this -

B.T.dot(A.swapaxes(0,1).reshape(n,-1)).reshape(n,n,-1).swapaxes(0,1)
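Because n appears in several reshape calls, an axis mix-up can go unnoticed when n happens to equal M. A sketch that checks every expression on this page against the loop from the question, using n=5 and M=3 (sizes chosen here so the two differ):

```python
import numpy as np

n, M = 5, 3  # deliberately different, so swapped axes would show up
rng = np.random.default_rng(42)
A = rng.random((n, n, M))
B = rng.random((n, n))

# Reference result: the per-slice loop from the question
ref = np.zeros_like(A)
for m in range(M):
    ref[:, :, m] = A[:, :, m] @ B

candidates = {
    "einsum": np.einsum('ijk,jl->ilk', A, B),
    "einsum_optimize": np.einsum('ijk,jl->ilk', A, B, optimize=True),
    "transpose_matmul": (A.transpose(2, 0, 1) @ B).transpose(1, 2, 0),
    "tensordot": np.tensordot(A, B, axes=(1, 0)).swapaxes(1, 2),
    "reshape_A_first": A.swapaxes(1, 2).reshape(-1, n).dot(B).reshape(n, -1, n).swapaxes(1, 2),
    "reshape_B_first": B.T.dot(A.swapaxes(0, 1).reshape(n, -1)).reshape(n, n, -1).swapaxes(0, 1),
}

for name, result in candidates.items():
    # allclose rather than array_equal: BLAS paths may sum in a different order
    ok = result.shape == ref.shape and np.allclose(result, ref)
    print(f"{name:>18}: {'ok' if ok else 'MISMATCH'}")
```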

Benchmarking

Setup -

np.random.seed(0)
n,M = 50,50
A = np.random.rand(n,n,M)
B = np.random.rand(n,n)
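The timings that follow use IPython's %timeit magic. Outside IPython, the standard-library timeit module gives comparable measurements. A sketch with the same setup; the repeat and number counts are arbitrary choices made here, and absolute times will depend on the machine and on the BLAS library NumPy is linked against:

```python
import timeit
import numpy as np

np.random.seed(0)
n, M = 50, 50
A = np.random.rand(n, n, M)
B = np.random.rand(n, n)

variants = {
    "einsum": lambda: np.einsum('ijk,jl->ilk', A, B),
    "einsum optimize=True": lambda: np.einsum('ijk,jl->ilk', A, B, optimize=True),
    "transpose + @": lambda: (A.transpose(2, 0, 1) @ B).transpose(1, 2, 0),
    "tensordot": lambda: np.tensordot(A, B, axes=(1, 0)).swapaxes(1, 2),
    "reshape, A first": lambda: A.swapaxes(1, 2).reshape(-1, n).dot(B).reshape(n, -1, n).swapaxes(1, 2),
    "reshape, B first": lambda: B.T.dot(A.swapaxes(0, 1).reshape(n, -1)).reshape(n, n, -1).swapaxes(0, 1),
}

results = {}
for name, fn in variants.items():
    # best of 5 repeats of 20 calls each, converted to milliseconds per call
    best = min(timeit.repeat(fn, repeat=5, number=20)) / 20
    results[name] = best * 1e3
    print(f"{name:>22}: {results[name]:.3f} ms per call")
```

Taking the minimum over repeats, as %timeit's "best of" does, reduces the influence of background noise on the reported figure.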

Timings -

# @Psidom's soln-1
In [18]: %timeit np.einsum('ijk,jl->ilk', A, B)
100 loops, best of 3: 10.2 ms per loop

# @Psidom's soln-2
In [19]: %timeit (A.transpose(2,0,1) @ B).transpose(1,2,0)
100 loops, best of 3: 10.7 ms per loop

# @Psidom's einsum soln-1 with optimize set to True
In [20]: %timeit np.einsum('ijk,jl->ilk', A, B,optimize=True)
1000 loops, best of 3: 1.17 ms per loop

In [21]: %timeit np.tensordot(A,B,axes=(1,0)).swapaxes(1,2)
1000 loops, best of 3: 1.09 ms per loop

In [22]: %timeit A.swapaxes(1,2).reshape(-1,n).dot(B).reshape(n,-1,n).swapaxes(1,2)
1000 loops, best of 3: 1.03 ms per loop

In [23]: %timeit B.T.dot(A.swapaxes(0,1).reshape(n,-1)).reshape(n,n,-1).swapaxes(0,1)
1000 loops, best of 3: 951 µs per loop
