繁体   English   中英

在numpy中计算矩阵乘积的迹线的最佳方法是什么?

[英]What is the best way to compute the trace of a matrix product in numpy?

如果我有numpy数组AB ,则可以使用以下公式计算其矩阵乘积的迹线:

tr = numpy.linalg.trace(A.dot(B))

但是,当轨迹中仅使用对角线元素时,矩阵乘法A.dot(B)不必要地计算矩阵乘积中的所有非对角线条目。 相反,我可以做类似的事情:

tr = 0.0
for i in range(n):
    tr += A[i, :].dot(B[:, i])

但这会在Python代码中执行循环,并且不如numpy.linalg.trace明显。

有没有更好的方法来计算numpy数组的矩阵乘积的轨迹? 最快或最惯用的方法是什么?

您可以通过仅将中间存储减少到对角线元素来改进@Bill的解决方案:

from numpy.core.umath_tests import inner1d

m, n = 1000, 500

a = np.random.rand(m, n)
b = np.random.rand(n, m)

# They all should give the same result
print np.trace(a.dot(b))
print np.sum(a*b.T)
print np.sum(inner1d(a, b.T))

%timeit np.trace(a.dot(b))
10 loops, best of 3: 34.7 ms per loop

%timeit np.sum(a*b.T)
100 loops, best of 3: 4.85 ms per loop

%timeit np.sum(inner1d(a, b.T))
1000 loops, best of 3: 1.83 ms per loop

另一个选择是使用np.einsum并且根本没有显式的中间存储:

# Will print the same as the others:
print np.einsum('ij,ji->', a, b)

在我的系统上,它的运行速度比使用inner1d慢一些,但可能不适用于所有系统,请参见以下问题

%timeit np.einsum('ij,ji->', a, b)
100 loops, best of 3: 1.91 ms per loop

您可以从Wikipedia中使用hadamard乘积(逐元素乘法)来计算跟踪:

# Tr(A.B)
tr = (A*B.T).sum()

我认为这需要比numpy.trace(A.dot(B))少的计算。

编辑:

跑一些计时器。 这种方法比使用numpy.trace

In [37]: timeit("np.trace(A.dot(B))", setup="""import numpy as np;
A, B = np.random.rand(1000,1000), np.random.rand(1000,1000)""", number=100)
Out[38]: 8.6434469223022461

In [39]: timeit("(A*B.T).sum()", setup="""import numpy as np;
A, B = np.random.rand(1000,1000), np.random.rand(1000,1000)""", number=100)
Out[40]: 0.5516049861907959

需要注意的是一个轻微的变体采取的点积vec torized矩阵。 在python中,矢量化是使用.flatten('F') 这比在我的计算机上将Hadamard产品的总和稍慢一些,因此,它比wflynny的解决方案差,但我认为这很有趣,因为在某些情况下,它可以更直观。 例如,我个人发现对于矩阵正态分布,矢量化解对我来说更容易理解。

在我的系统上的速度比较:

import numpy as np
import time

N = 1000

np.random.seed(123)
A = np.random.randn(N, N)
B = np.random.randn(N, N)

tart = time.time()
for i in range(10):
    C = np.trace(A.dot(B))
print(time.time() - start, C)

start = time.time()
for i in range(10):
    C = A.flatten('F').dot(B.T.flatten('F'))
print(time.time() - start, C)

start = time.time()
for i in range(10):
    C = (A.T * B).sum()
print(time.time() - start, C)

start = time.time()
for i in range(10):
    C = (A * B.T).sum()
print(time.time() - start, C)

结果:

6.246593236923218 -629.370798672
0.06539678573608398 -629.370798672
0.057890892028808594 -629.370798672
0.05709719657897949 -629.370798672

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM