[英]Why is matrix multiplication with Numba slow?
我试图找到一个解释,为什么我与 Numba 的矩阵乘法比使用 NumPy 的点 function 慢得多。 尽管我使用最基本的代码来编写带有 Numba 的矩阵乘法 function,但我不认为性能明显变慢是由于算法。 为简单起见,我考虑两个 kxk 方阵,A 和 B。我的代码如下
1 @njit('f8[:,:](f8[:,:], f8[:,:])')
2 def numba_dot(A, B):
3
4 k=A.shape[1]
5 C = np.zeros((k, k))
6
7 for i in range(k):
8 for j in range(k):
9
10 tmp = 0.
11 for l in range(k):
12 tmp += A[i, l] * B[l, j]
13
14 C[i, j] = tmp
15
16 return C
使用两个随机矩阵 1000 x 1000 矩阵重复运行此代码,通常至少需要大约 1.5 秒才能完成。 另一方面,如果我不更新矩阵 C,即如果我删除第 14 行,或者为了测试而将其替换为例如以下行:
14 C[i, j] = i * j
代码在大约 1-5 毫秒内完成。 相比之下,NumPy 的点 function 需要大约 10 毫秒的矩阵乘法。
上述矩阵乘法代码与这个小变化之间的运行时间差异背后的原因是什么? 有没有办法在不显着降低代码性能的情况下将变量 tmp 的值存储在 C[i, j] 中?
本机NumPy
实现适用于矢量化操作。 如果您的 CPU 支持这些,则处理速度会快得多。 当前的微处理器具有片上矩阵乘法,它对数据传输和向量操作进行流水线化。
您的实现执行 k^3 循环迭代; 十亿的任何事情都需要一些不平凡的时间。 您的代码指定您要单独执行每个单元一个单元的操作,十亿个不同的操作,而不是并行和流水线完成的大约 5k 个操作。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.