Speed up Cython implementation of dot product multiplication
I'm trying to learn Cython by trying to outperform NumPy at the dot product operation np.dot(a,b). But my implementation is about 4x slower.
So, this is my Cython implementation in the hello.pyx file:
cimport numpy as cnp
cnp.import_array()
cpdef double dot_product(double[::1] vect1, double[::1] vect2):
    cdef int size = vect1.shape[0]
    cdef double result = 0
    cdef int i = 0
    while i < size:
        result += vect1[i] * vect2[i]
        i += 1
    return result
This is my test file, my.py:
import timeit
setup = '''
import numpy as np
import hello
n = 10000
a = np.array([float(i) for i in range(n)])
b = np.array([i/2 for i in a])
'''
lf_code = 'res_lf = hello.dot_product(a, b)'
np_code = 'res_np = np.dot(a,b)'
n = 100
lf_time = timeit.timeit(lf_code, setup=setup, number=n) * 100
np_time = timeit.timeit(np_code, setup=setup, number=n) * 100
print(f'Lightning fast time: {lf_time}.')
print(f'Numpy time: {np_time}.')
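Note that timeit.timeit returns the total time in seconds for all `number` executions, so dividing by `number` gives the time per call. A minimal standard-library sketch of a per-call measurement follows; `dot_python` and `time_per_call` are hypothetical names, and the pure-Python `dot_python` only stands in for hello.dot_product so the snippet runs without the compiled extension:

```python
import timeit

def dot_python(vect1, vect2):
    """Hypothetical pure-Python stand-in for hello.dot_product."""
    result = 0.0
    for x, y in zip(vect1, vect2):
        result += x * y
    return result

def time_per_call(func, *args, number=100, repeat=3):
    """Return the best per-call time in seconds over `repeat` runs of `number` calls."""
    # Each entry of `totals` is the total time for `number` calls.
    totals = timeit.repeat(lambda: func(*args), number=number, repeat=repeat)
    # Take the fastest run (least disturbed by other processes), then divide.
    return min(totals) / number

if __name__ == "__main__":
    n = 10000
    a = [float(i) for i in range(n)]
    b = [x / 2 for x in a]
    t = time_per_call(dot_python, a, b)
    print(f"dot_python: {t * 1e6:.1f} us per call")
```

With the compiled module, you would pass hello.dot_product and the NumPy arrays a and b instead.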
Console output:
Lightning fast time: 0.12186000000156127.
Numpy time: 0.028800000001183435.
Command to build hello.pyx:
python setup.py build_ext --inplace
setup.py file:
from setuptools import Extension, setup
from Cython.Build import cythonize
import numpy as np
# define an extension that will be cythonized and compiled
ext = Extension(name="hello", sources=["hello.pyx"], include_dirs=[np.get_include()])
setup(ext_modules=cythonize(ext))
Processor: i7-7700T @ 2.90 GHz
The problem mainly comes from the lack of SIMD instructions compared to NumPy (which uses OpenBLAS by default on most platforms). This is due both to the bounds checking and to the inefficient default compiler flags.
To fix that, you should first add the following line at the beginning of the hello.pyx file:
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False
Then, you should use this new setup.py file:
from setuptools import Extension, setup
from Cython.Build import cythonize
import numpy as np
# define an extension that will be cythonized and compiled
ext = Extension(name="hello", sources=["hello.pyx"], include_dirs=[np.get_include()], extra_compile_args=['-O3', '-mavx', '-ffast-math'])
setup(ext_modules=cythonize(ext))
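The flags above use the GCC/Clang spelling. If the extension might also be built with MSVC, whose rough equivalents are /O2, /arch:AVX and /fp:fast, one option is to choose the list per compiler family. A minimal sketch; `avx_fast_math_flags` is a hypothetical helper, and it assumes the compiler-type strings used by setuptools ("msvc" for MSVC, "unix" or "mingw32" for GCC/Clang):

```python
def avx_fast_math_flags(compiler_type):
    """Return optimization flags for a compiler family (hypothetical helper).

    compiler_type is assumed to be a setuptools compiler type string,
    e.g. "unix" (GCC/Clang on Linux/macOS), "mingw32" or "msvc".
    """
    if compiler_type == "msvc":
        # MSVC: full optimization, AVX code generation, relaxed FP model.
        return ["/O2", "/arch:AVX", "/fp:fast"]
    # GCC and Clang (including MinGW) accept the same spellings as setup.py above.
    return ["-O3", "-mavx", "-ffast-math"]
```

One place to call it is a custom build_ext subclass whose build_extensions method reads self.compiler.compiler_type and appends the returned flags to each extension's extra_compile_args before calling the parent implementation.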
Note that the flags depend on the compiler. That being said, both Clang and GCC support them (and probably ICC too):
-O3 tells Clang and GCC to use more aggressive optimizations, such as automatic vectorization of the code.
-mavx tells them to use the AVX instruction set (which is only available on relatively recent x86-64 processors).
-ffast-math tells them to assume that floating-point operations are associative (which is not the case) and that you only use finite, ordinary numbers (no NaN and no infinities).
If these assumptions are not fulfilled, the program can produce wrong results or even crash at runtime (for example, an AVX build fails with an illegal-instruction error on a CPU without AVX), so be careful with such flags.
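The associativity assumption is what allows the reduction to be vectorized: with 4-wide SIMD, the compiler keeps one partial sum per lane instead of adding the products strictly left to right, which changes the rounding. A quick plain-Python illustration; `lane_sum` is a hypothetical helper that mimics the per-lane order of a 4-wide reduction:

```python
# Floating-point addition is not associative: regrouping changes the rounding.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left, right, left == right)  # 0.6000000000000001 0.6 False

def lane_sum(values, lanes=4):
    """Sum values with one partial sum per SIMD lane, then combine the lanes."""
    partial = [0.0] * lanes
    for i, v in enumerate(values):
        partial[i % lanes] += v  # element i goes to lane i % lanes
    return sum(partial)
```

For arbitrary inputs, lane_sum and a left-to-right sum can differ in the last bits, which is exactly the reordering that -ffast-math permits.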
Note that OpenBLAS automatically selects the instruction set based on your machine and, AFAIK, it does not use -ffast-math but a safer (low-level) alternative.
Here are the results on my machine:
Before optimization:
Lightning fast time: 0.10018469997703505.
Numpy time: 0.024747799989199848.
After (with GCC):
Lightning fast time: 0.02865879996534204.
Numpy time: 0.02456870001878997.
After (with Clang):
Lightning fast time: 0.01965239998753532.
Numpy time: 0.024799799984975834.
The code produced by Clang is faster than NumPy on my machine.
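Since -ffast-math lets the compiler reorder the accumulation, the optimized Cython result and np.dot need not match bit for bit, so a correctness check should compare with a tolerance. A sketch using only the standard library, where math.fsum gives an accurately rounded sum of the (already rounded) products; `check_dot` and the default tolerance are assumptions to adapt:

```python
import math

def check_dot(result, vect1, vect2, rel_tol=1e-12):
    """Return True if `result` is close to an accurately summed dot product."""
    # fsum tracks partial sums exactly, so its result does not depend on the
    # order of the products; each product itself is still a rounded double.
    reference = math.fsum(x * y for x, y in zip(vect1, vect2))
    return math.isclose(result, reference, rel_tol=rel_tol)
```

For example, check_dot(hello.dot_product(a, b), a, b) would accept a result whose only difference comes from a different summation order.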
An analysis of the assembly code executed by the processor on my machine shows that the original code only uses slow scalar instructions, contains unnecessary bounds checks and is mainly limited by the result += ... operation (because of a loop-carried dependency: each addition must wait for the previous one to complete).
162e3:┌─→movsd xmm0,QWORD PTR [rbx+rax*8] # Load 1 item
162e8:│ mulsd xmm0,QWORD PTR [rsi+rax*8] # Load 1 item
162ed:│ addsd xmm1,xmm0 # Main bottleneck (accumulation)
162f1:│ cmp rdi,rax
162f4:│↓ je 163f8 # Bound checking conditional jump
162fa:│ cmp rdx,rax
162fd:│↓ je 16308 # Bound checking conditional jump
162ff:│ add rax,0x1
16303:├──cmp rcx,rax
16306:└──jne 162e3
Once optimized, the result is:
13720:┌─→vmovupd ymm3,YMMWORD PTR [r13+rax*1+0x0] # Load 4 items
13727:│ vmulpd ymm0,ymm3,YMMWORD PTR [rcx+rax*1] # Load 4 items
1372c:│ add rax,0x20
13730:│ vaddpd ymm1,ymm1,ymm0 # Still a bottleneck (but better)
13734:├──cmp rdx,rax
13737:└──jne 13720
The result += ... operation is still the bottleneck in the optimized version, but this is much better since the loop works on 4 items at once. To remove the bottleneck, the loop must be partially unrolled using several independent accumulators. However, GCC (which is the default compiler on my machine) is not able to do that properly, even when asked to with -funroll-loops, due to the loop-carried dependency. This is why OpenBLAS should be a bit faster than the code produced by GCC.
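What unrolling with independent accumulators means can be sketched in plain Python. This only illustrates the data flow and will not be faster in CPython: four partial sums, each updated from its own position in the group, so no addition has to wait for the previous one, and they are combined once at the end. It is analogous to the four ymm accumulators in the Clang loop (which additionally holds 4 doubles per accumulator). The name `dot_unrolled4` is hypothetical:

```python
def dot_unrolled4(vect1, vect2):
    """Dot product with 4 independent accumulators (illustrative sketch)."""
    n = len(vect1)
    acc0 = acc1 = acc2 = acc3 = 0.0
    i = 0
    # Main loop: 4 items per iteration, each feeding a different accumulator,
    # so consecutive additions do not depend on each other.
    while i + 4 <= n:
        acc0 += vect1[i] * vect2[i]
        acc1 += vect1[i + 1] * vect2[i + 1]
        acc2 += vect1[i + 2] * vect2[i + 2]
        acc3 += vect1[i + 3] * vect2[i + 3]
        i += 4
    # Combine the partial sums, then handle the last n % 4 items.
    result = (acc0 + acc1) + (acc2 + acc3)
    while i < n:
        result += vect1[i] * vect2[i]
        i += 1
    return result
```

The same structure can be written in the Cython function with cdef double accumulators; since it changes the summation order, the result can differ in the last bits, and whether it helps with a given compiler is something to measure.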
Fortunately, Clang is able to do that by default. Here is the code produced by Clang:
59e0:┌─→vmovupd ymm4,YMMWORD PTR [rax+rdi*8] # load 4 items
59e5:│ vmovupd ymm5,YMMWORD PTR [rax+rdi*8+0x20] # load 4 items
59eb:│ vmovupd ymm6,YMMWORD PTR [rax+rdi*8+0x40] # load 4 items
59f1:│ vmovupd ymm7,YMMWORD PTR [rax+rdi*8+0x60] # load 4 items
59f7:│ vmulpd ymm4,ymm4,YMMWORD PTR [rbx+rdi*8]
59fc:│ vaddpd ymm0,ymm4,ymm0
5a00:│ vmulpd ymm4,ymm5,YMMWORD PTR [rbx+rdi*8+0x20]
5a06:│ vaddpd ymm1,ymm4,ymm1
5a0a:│ vmulpd ymm4,ymm6,YMMWORD PTR [rbx+rdi*8+0x40]
5a10:│ vmulpd ymm5,ymm7,YMMWORD PTR [rbx+rdi*8+0x60]
5a16:│ vaddpd ymm2,ymm4,ymm2
5a1a:│ vaddpd ymm3,ymm5,ymm3
5a1e:│ add rdi,0x10
5a22:├──cmp rsi,rdi
5a25:└──jne 59e0
The code is not optimal (it should unroll the loop at least 6 times because of the latency of the vaddpd instruction), but it is very good.
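The reasoning behind such an unroll factor is usually Little's law: to keep the floating-point adder busy, roughly latency × (additions issued per cycle) independent accumulators must be in flight. A small helper for that arithmetic; `accumulators_needed` is a hypothetical name, and the latency and throughput values are parameters to look up in your CPU's instruction tables (the numbers in the tests are placeholders, not measurements for the i7-7700T):

```python
import math

def accumulators_needed(latency_cycles, per_cycle):
    """Minimum number of independent accumulators needed to hide an add's latency.

    latency_cycles: cycles before an addition's result can be reused.
    per_cycle: how many such additions the core can start each cycle.
    """
    return math.ceil(latency_cycles * per_cycle)
```

For instance, an add with a 4-cycle latency that can be issued twice per cycle would need 8 accumulators; with 4-wide AVX vectors, that means 32 doubles per loop iteration.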