简体   繁体   English

将Numpy中的矩阵列表相乘

[英]Multiply together list of matrices in Numpy

I'm looking for an efficient way to multiply a list of matrices in Numpy.我正在寻找一种有效的方法来乘以 Numpy 中的矩阵列表。 I have a matrix like this:我有一个这样的矩阵:

import numpy as np
a = np.random.randn(1000, 4, 4)

I want to matrix-multiply along the long axis, so the result is a 4x4 matrix.我想沿长轴进行矩阵乘法,所以结果是一个 4x4 矩阵。 So clearly I can do:很明显我可以做到:

res = np.identity(4)
for ai in a:
    res = np.matmul(res, ai)

But this is super-slow.但这是超慢的。 Is there a faster way (perhaps using einsum or some other function that I don't fully understand yet)?有没有更快的方法(也许使用einsum或其他一些我还不完全理解的函数)?

A solution that requires log_2(n) for loop interations for stacks with size of powers of 2 could be一个需要log_2(n) for大小为 2 幂的堆栈的循环交互的解决方案可能是

while len(a) > 1:
    a = np.matmul(a[::2, ...], a[1::2, ...])

which essentially iteratively multiplies two neighbouring matrices together until there is only one matrix left, doing half of the remaining multiplications per iteration.它本质上是将两个相邻矩阵迭代地相乘,直到只剩下一个矩阵,每次迭代执行剩余乘法的一半。

res = A * B * C * D * ...         # 1024 remaining multiplications

becomes变成

res = (A * B) * (C * D) * ...     # 512 remaining multiplications

becomes变成

res = ((A * B) * (C * D)) * ...   # 256 remaining multiplications

etc.等等。

For non-powers of 2 you can do this for the first 2^n matrices and use your algorithm for the remaining matrices.对于 2 的非幂,您可以对前2^n矩阵执行此操作,并对其余矩阵使用您的算法。

np.linalg.multi_dot does this sort of chaining. np.linalg.multi_dot做这种链接。

In [119]: a = np.random.randn(5, 4, 4)
In [120]: res = np.identity(4)
In [121]: for ai in a: res = np.matmul(res, ai)
In [122]: res
Out[122]: 
array([[ -1.04341835,  -1.22015464,   9.21459712,   0.97214725],
       [ -0.13652679,   0.61012689,  -0.07325689,  -0.17834132],
       [ -2.45684401,  -1.76347514,  12.41094524,   1.00411347],
       [ -8.36738671,  -6.5010718 ,  15.32489832,   3.62426123]])
In [123]: np.linalg.multi_dot(a)
Out[123]: 
array([[ -1.04341835,  -1.22015464,   9.21459712,   0.97214725],
       [ -0.13652679,   0.61012689,  -0.07325689,  -0.17834132],
       [ -2.45684401,  -1.76347514,  12.41094524,   1.00411347],
       [ -8.36738671,  -6.5010718 ,  15.32489832,   3.62426123]])

But it is slower, 92.3 µs per loop v 22.2 µs per loop.但它更慢,每个循环 92.3 µs v 每个循环 22.2 µs。 And for your 1000 item case, the test timing is still running.对于您的 1000 件商品,测试计时仍在运行。

After figuring out some 'optimal order' multi_dot does a recursive dot .在找出一些“最佳顺序”之后, multi_dot做了一个递归dot

def _multi_dot(arrays, order, i, j):
    """Actually do the multiplication with the given order."""
    if i == j:
        return arrays[i]
    else:
        return dot(_multi_dot(arrays, order, i, order[i, j]),
                   _multi_dot(arrays, order, order[i, j] + 1, j))

In the 1000 item case this hits a recursion depth error.在 1000 项的情况下,这会遇到递归深度错误。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM