繁体   English   中英

numpy multi_dot比numpy.dot慢吗?

[英]How is numpy multi_dot slower than numpy.dot?

我正在尝试优化一些执行大量顺序矩阵运算的代码。

我认为numpy.linalg.multi_dotdocs此处 )将在C或BLAS中执行所有操作,因此比诸如arr1.dot(arr2).dot(arr3)类的arr1.dot(arr2).dot(arr3)要快得多。

我真的很惊讶在笔记本上运行以下代码:

v1 = np.random.rand(2,2)

v2 = np.random.rand(2,2)



%%timeit 
    ​    
v1.dot(v2.dot(v1.dot(v2)))

The slowest run took 9.01 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 3.14 µs per loop



%%timeit        ​

np.linalg.multi_dot([v1,v2,v1,v2])

The slowest run took 4.67 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 32.9 µs per loop

要发现使用multi_dot进行相同的操作要慢大约10倍。

我的问题是:

  • 我想念什么吗? 有什么意义吗?
  • 有没有其他方法可以优化顺序矩阵运算?
  • 我应该期望使用cython的行为相同吗?

这是因为您的测试矩阵太小且太规则。 确定最快评估顺序的开销可能会超过潜在的性能提升。

使用文档中的示例:

import numpy as snp
from numpy.linalg import multi_dot

# Prepare some data
A = np.random.rand(10000, 100)
B = np.random.rand(100, 1000)
C = np.random.rand(1000, 5)
D = np.random.rand(5, 333)

%timeit -n 10 multi_dot([A, B, C, D])
%timeit -n 10 np.dot(np.dot(np.dot(A, B), C), D)
%timeit -n 10 A.dot(B).dot(C).dot(D)

结果:

10 loops, best of 3: 12 ms per loop
10 loops, best of 3: 62.7 ms per loop
10 loops, best of 3: 59 ms per loop

multi_dot通过评估标量乘法最少的最快乘法顺序来提高性能。

在上述情况下,默认的常规乘法顺序((AB)C)D被评估为A((BC)D)这样,将1000x100 @ 100x1000乘法减少为1000x100 @ 100x333 ,减少了至少2/3标量乘法。

您可以通过测试验证

%timeit -n 10 np.dot(A, np.dot(np.dot(B, C), D))
10 loops, best of 3: 19.2 ms per loop

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM