
Applying a matrix dot product between a list of matrices and a list of vectors in NumPy

Suppose I have these two variables:

import numpy as np
matrices = np.random.rand(4,3,3)
vectors = np.random.rand(4,3,1)

What I would like to perform is the following:

dot_products = [matrix @ vector for (matrix,vector) in zip(matrices,vectors)]
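Written out element by element, the list comprehension computes out[i, j, l] = sum over k of matrices[i, j, k] * vectors[i, k, l]. The batch index i is shared by both inputs and kept in the output, while k is the summed ("dotted") index. A deliberately slow sketch with explicit loops makes that bookkeeping concrete; the helper name batch_matvec_loops is hypothetical and only meant as a reference to check faster versions against:

```python
import numpy as np

rng = np.random.default_rng(0)
matrices = rng.random((4, 3, 3))
vectors = rng.random((4, 3, 1))

def batch_matvec_loops(matrices, vectors):
    """Hypothetical reference helper: out[i, j, l] = sum_k matrices[i, j, k] * vectors[i, k, l]."""
    n, rows, inner = matrices.shape
    n_v, inner_v, cols = vectors.shape
    assert n == n_v and inner == inner_v, "batch sizes and inner dimensions must match"
    out = np.zeros((n, rows, cols))
    for i in range(n):                  # batch index: pairs matrices[i] with vectors[i]
        for j in range(rows):           # row of each result
            for l in range(cols):       # column of each result (1 for column vectors)
                for k in range(inner):  # summed ("dotted") index
                    out[i, j, l] += matrices[i, j, k] * vectors[i, k, l]
    return out

dot_products = [matrix @ vector for (matrix, vector) in zip(matrices, vectors)]
reference = batch_matvec_loops(matrices, vectors)
print(reference.shape)                                 # (4, 3, 1)
print(np.allclose(reference, np.array(dot_products)))  # True
```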

I tried the np.tensordot method, which at first seemed to make sense, but this is what happened when I tested it:

>>> np.tensordot(matrices,vectors,axes=([-2,-1],[-2,-1]))
...
ValueError: shape-mismatch for sum 
>>> np.tensordot(matrices,vectors,axes=([-2,-1]))
...
ValueError: shape-mismatch for sum 

Is it possible to compute all of these dot products at once with np.tensordot? If not, is there another way to do it in NumPy?

The @ operator is documented under np.matmul. It is specifically designed for this kind of 'batch' processing:

In [76]: matrices = np.random.rand(4,3,3)
    ...: vectors = np.random.rand(4,3,1)

In [77]: dot_products = [matrix @ vector for (matrix,vector) in zip(matrices,vectors)]
In [79]: np.array(dot_products).shape
Out[79]: (4, 3, 1)

In [80]: (matrices @ vectors).shape
Out[80]: (4, 3, 1)

In [81]: np.allclose(np.array(dot_products), matrices@vectors)
Out[81]: True
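Because np.matmul treats every axis except the last two as a batch axis and broadcasts those, the same operator also covers a few related layouts. The sketch below assumes the data might instead be one shared matrix, one shared vector, or vectors stored as plain rows of a (4, 3) array. In that last case a trailing axis has to be added first; otherwise matmul would treat the (4, 3) array as a single 4-by-3 matrix and fail on the inner dimensions.

```python
import numpy as np

rng = np.random.default_rng(1)
matrices = rng.random((4, 3, 3))  # stack of 4 matrices
vectors = rng.random((4, 3, 1))   # stack of 4 column vectors

# One matrix applied to every vector: the (3, 3) operand broadcasts over the batch axis
single_matrix = matrices[0]
out_a = single_matrix @ vectors                       # shape (4, 3, 1)

# Every matrix applied to the same vector: the (3, 1) operand broadcasts instead
single_vector = vectors[0]
out_b = matrices @ single_vector                      # shape (4, 3, 1)

# Vectors stored as rows of a (4, 3) array: add a trailing axis so matmul
# sees a stack of (3, 1) columns, then drop that axis again
flat_vectors = vectors[..., 0]                        # shape (4, 3)
out_c = (matrices @ flat_vectors[..., None])[..., 0]  # shape (4, 3)

print(out_a.shape, out_b.shape, out_c.shape)             # (4, 3, 1) (4, 3, 1) (4, 3)
print(np.allclose(out_c, (matrices @ vectors)[..., 0]))  # True
```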

There are a couple of problems with tensordot. The axes parameter specifies which dimensions are summed ("dotted"). In your case those would be the last axis of matrices and the second-to-last axis of vectors. That's the standard dot pairing.
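tensordot requires each pair of summed axes to have the same length, which is where both attempts in the question break. The hypothetical helper paired_axis_sizes below just lists the lengths tensordot would compare. Note that axes=([-2,-1]) is simply the list [-2, -1], which tensordot reads as axis -2 of matrices paired with axis -1 of vectors.

```python
import numpy as np

matrices = np.random.rand(4, 3, 3)
vectors = np.random.rand(4, 3, 1)

def paired_axis_sizes(a, b, axes_a, axes_b):
    """Hypothetical helper: (length in a, length in b) for each pair of summed axes."""
    return [(a.shape[x], b.shape[y]) for x, y in zip(axes_a, axes_b)]

# First attempt, axes=([-2, -1], [-2, -1]): the second pair is 3 vs 1
print(paired_axis_sizes(matrices, vectors, [-2, -1], [-2, -1]))  # [(3, 3), (3, 1)]

# Second attempt, axes=([-2, -1]) == [-2, -1]: axis -2 of matrices vs axis -1 of vectors
print(paired_axis_sizes(matrices, vectors, [-2], [-1]))          # [(3, 1)]

# The standard dot pairing: last axis of matrices vs second-to-last of vectors
print(paired_axis_sizes(matrices, vectors, [-1], [-2]))          # [(3, 3)]
```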

In [82]: np.dot(matrices, vectors).shape
Out[82]: (4, 3, 4, 1)
In [84]: np.tensordot(matrices, vectors, (-1,-2)).shape
Out[84]: (4, 3, 4, 1)

You tried to specify two pairs of axes for summing. Also, dot/tensordot does a kind of outer product over the remaining dimensions, so you'd have to take the "diagonal" over the two size-4 axes. tensordot is not what you want for this operation.
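To make the "diagonal" remark concrete, here is a sketch that recovers the batched result from the (4, 3, 4, 1) tensordot output by keeping only the entries where both batch indices agree. It also shows why this is the wrong tool: it computes all 4 x 4 matrix-vector products and then discards 12 of them.

```python
import numpy as np

rng = np.random.default_rng(2)
matrices = rng.random((4, 3, 3))
vectors = rng.random((4, 3, 1))

# Standard dot pairing: last axis of matrices with second-to-last axis of vectors.
# The two batch axes are not paired, so they form an outer product.
full = np.tensordot(matrices, vectors, axes=(-1, -2))  # shape (4, 3, 4, 1)
# full[i, :, m, :] equals matrices[i] @ vectors[m] for every combination of i and m

# Keep only the i == m entries (the "diagonal" over the two size-4 axes).
# The two index arrays are separated by a slice, so their shared axis moves to the front.
idx = np.arange(matrices.shape[0])
batched = full[idx, :, idx, :]                         # shape (4, 3, 1)

print(full.shape, batched.shape)                 # (4, 3, 4, 1) (4, 3, 1)
print(np.allclose(batched, matrices @ vectors))  # True
```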

We can be more explicit about the dimensions with einsum:

In [83]: np.einsum('ijk,ikl->ijl',matrices, vectors).shape
Out[83]: (4, 3, 1)
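In the einsum subscripts, an index that appears in both inputs and in the output (i) is a batch axis, an index that appears in both inputs but not in the output (k) is summed, and the rest are carried along. Assuming the vectors might instead be stored as a (4, 3) array, the trailing l axis can simply be dropped from the subscripts, with no reshaping needed:

```python
import numpy as np

rng = np.random.default_rng(3)
matrices = rng.random((4, 3, 3))
vectors = rng.random((4, 3, 1))

# i: batch axis (shared, kept), j: row of each matrix (kept),
# k: summed axis (in both inputs, absent from the output),
# l: trailing column axis of the vectors (kept)
out_3d = np.einsum('ijk,ikl->ijl', matrices, vectors)     # shape (4, 3, 1)

# With vectors stored as a (4, 3) array there is no l axis to carry along
flat_vectors = vectors[..., 0]
out_2d = np.einsum('ijk,ik->ij', matrices, flat_vectors)  # shape (4, 3)

print(np.allclose(out_3d, matrices @ vectors))  # True
print(np.allclose(out_2d, out_3d[..., 0]))      # True
```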
