繁体   English   中英

矩阵乘以一个numpy矩阵阵列

[英]Matrix multiply a numpy array of matrices

我正在扩展代码,旨在对2个向量执行一个函数,以便它代替处理2个向量数组。 我正在使用numpy.dot来获取两个3x3矩阵的乘积。 现在我想用3x3矩阵数组来做这个。 我无法弄清楚如何用numpy.einsum做这个,但我认为这就是我需要的,我只是在努力去理解它是如何工作的。

这是我想要使用循环的示例。 有没有办法在没有循环的情况下做到这一点?

>>> import numpy as np
>>> m = np.arange(27).reshape(3,3,3)
>>> print m
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]]])
>>> m2 = np.zeros(m.shape)
>>> for i in length(m):
        m2[i] = np.dot(m[i],m[i])
>>> print m2
    array([[[   15.,    18.,    21.],
            [   42.,    54.,    66.],
            [   69.,    90.,   111.]],

           [[  366.,   396.,   426.],
            [  474.,   513.,   552.],
            [  582.,   630.,   678.]],

           [[ 1203.,  1260.,  1317.],
            [ 1392.,  1458.,  1524.],
            [ 1581.,  1656.,  1731.]]])

我在这篇文章Python中发现了一个numpy.einsum语法,numpy,einsum乘以一堆矩阵 ,它可以满足我的需求。 我不清楚它为什么有用,并且想要了解如何构造索引字符串以供将来使用。

>>> print np.einsum('fij,fjk->fik', V, V)
    [[[  15   18   21]
      [  42   54   66]
      [  69   90  111]]

     [[ 366  396  426]
      [ 474  513  552]
      [ 582  630  678]]

     [[1203 1260 1317]
      [1392 1458 1524]
      [1581 1656 1731]]]

你也可以使用熊猫。 在下面的示例中,'p'相当于您的'm',并且是数据的3D表示。 使用列表推导,p2计算每个矩阵的点积。 为了进行比较,然后将结果转换回numpy数组列表。

import pandas as pd

%%timeit
m = np.arange(27).reshape(3,3,3)
p = pd.Panel(m)
p2 = [p[i].dot(p[i]) for i in p.items]

1000 loops, best of 3: 846 µs per loop

m2 = [p2[i].values for i in p2.items]
print(m2)

[array([[ 15,  18,  21],
       [ 42,  54,  66],
       [ 69,  90, 111]]), 
array([[366, 396, 426],
       [474, 513, 552],
       [582, 630, 678]]), 
array([[1203, 1260, 1317],
       [1392, 1458, 1524],
       [1581, 1656, 1731]])]

然而,Numpy要快得多。

%%timeit
np.einsum('fij,fjk->fik', m, m)

100000 loops, best of 3: 5.01 µs per loop

直接将它与np.dot进行比较:

%%timeit
[np.dot(m[i], m[i]) for i in range(len(m))]

100000 loops, best of 3: 6.78 µs per loop

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM