
Speed up Cython implementation of dot product multiplication

I'm trying to learn Cython by trying to outperform NumPy at the dot product operation np.dot(a,b). But my implementation is about 4x slower.

So, this is my Cython implementation in the hello.pyx file:

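cimport numpy as cnp

cpdef double dot_product(double[::1] vect1, double[::1] vect2):
    # Py_ssize_t matches the type of memoryview shapes and indices (int can overflow)
    cdef Py_ssize_t size = vect1.shape[0]
    cdef double result = 0
    cdef Py_ssize_t i = 0
    # Naive scalar loop: each addition depends on the previous value of result
    while i < size:
        result += vect1[i] * vect2[i]
        i += 1
    return result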

This is my Python test file:

import timeit

# Executed once before timing: builds two float64 arrays of 10,000 elements
setup = '''
import numpy as np
import hello

n = 10000
a = np.array([float(i) for i in range(n)])
b = np.array([i/2 for i in a])
'''

lf_code = 'res_lf = hello.dot_product(a, b)'
np_code = 'res_np = np.dot(a,b)'

n = 100  # number of timed repetitions
# timeit returns the total time in seconds for n runs; both results are scaled by 100
lf_time = timeit.timeit(lf_code, setup=setup, number=n) * 100
np_time = timeit.timeit(np_code, setup=setup, number=n) * 100

print(f'Lightning fast time: {lf_time}.')
print(f'Numpy time: {np_time}.')

Console output:

Lightning fast time: 0.12186000000156127.
Numpy time: 0.028800000001183435.

Command to build hello.pyx:

python setup.py build_ext --inplace

setup.py file:

from setuptools import Extension, setup  # distutils was removed in Python 3.12
from Cython.Build import cythonize
import numpy as np

# define an extension that will be cythonized and compiled
ext = Extension(name="hello", sources=["hello.pyx"], include_dirs=[np.get_include()])

setup(ext_modules=cythonize([ext]))

Processor: i7-7700T @ 2.90 GHz

The main problem is that, unlike NumPy (which uses OpenBLAS by default on most platforms), your code does not use SIMD instructions. This is due to both the bounds checking and the inefficient default compiler flags.

To fix that, you should first add the following line at the beginning of the hello.pyx file:

#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False

Then, you should use this new setup.py file:

from setuptools import Extension, setup  # distutils was removed in Python 3.12
from Cython.Build import cythonize
import numpy as np

# define an extension that will be cythonized and compiled
# (extra_compile_args are GCC/Clang flags)
ext = Extension(name="hello", sources=["hello.pyx"], include_dirs=[np.get_include()], extra_compile_args=['-O3', '-mavx', '-ffast-math'])

setup(ext_modules=cythonize([ext]))

Note that the flags depend on the compiler. That being said, both Clang and GCC support them (and probably ICC too). -O3 tells Clang and GCC to use more aggressive optimizations, such as automatic vectorization of the code. -mavx tells them to use the AVX instruction set, which is only available on relatively recent x86-64 processors (a module built this way will not run on a processor without AVX).

-ffast-math tells them to assume that floating-point operations are associative (which is not the case) and that you only use finite numbers (no NaN nor infinities). The associativity assumption is what allows the compiler to split the result += ... accumulation into several partial sums and vectorize it. If these assumptions do not hold, the program can produce wrong results or even crash at runtime, so be careful with such flags.
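A quick way to see why the compiler needs permission to reassociate is to run the same additions in two different orders. This is a plain Python sketch (Python floats are IEEE-754 doubles, like the C double in hello.pyx); the function names are made up for illustration, and four_lane_sum only mimics the summation order of a 4-wide vectorized loop, it does not use SIMD.

```python
def sequential_sum(values):
    """Strict left-to-right accumulation, like the scalar `result += ...` loop."""
    total = 0.0
    for v in values:
        total += v
    return total


def four_lane_sum(values):
    """Accumulate into 4 partial sums (one per SIMD lane), then combine them.

    Only the order of the additions is modelled here. For brevity, this
    sketch assumes len(values) is a multiple of 4.
    """
    lanes = [0.0, 0.0, 0.0, 0.0]
    for i in range(0, len(values), 4):
        for lane in range(4):
            lanes[lane] += values[i + lane]
    # Horizontal reduction of the 4 lanes once the loop is done
    return (lanes[0] + lanes[1]) + (lanes[2] + lanes[3])


# Associativity already fails with three numbers
print((0.1 + 0.2) + 0.3)  # 0.6000000000000001
print(0.1 + (0.2 + 0.3))  # 0.6

# Same four values, two summation orders, two different results
values = [1e16, 1.0, -1e16, 1.0]
print(sequential_sum(values))  # 1.0
print(four_lane_sum(values))   # 0.0
```

Neither order gives the exact sum (2.0); they are simply rounded differently. This kind of reordering is what -ffast-math permits. With the benchmark's arrays (integers and halves), every partial sum is exactly representable, so there the result does not change.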

Note that OpenBLAS automatically selects the instruction set based on your machine and, AFAIK, it does not use -ffast-math but a safer (low-level) alternative.
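If the same setup.py has to build on machines that may lack AVX, one option is to add -mavx only when the build machine reports it. The helpers below are hypothetical (they are not part of setuptools or Cython) and rely on the Linux-specific /proc/cpuinfo file. Unlike OpenBLAS, which picks its kernels at runtime, this decides once at build time, so the resulting module still requires AVX on whatever machine runs it. GCC and Clang also accept -march=native, which targets the instruction set of the build machine.

```python
def cpu_flags(cpuinfo_text):
    """Return the set of CPU feature flags from /proc/cpuinfo-style text."""
    for line in cpuinfo_text.splitlines():
        # On x86 Linux, features are listed on lines like "flags : fpu sse2 avx ..."
        key, sep, value = line.partition(":")
        if sep and key.strip() == "flags":
            return set(value.split())
    return set()


def choose_compile_args(flags):
    """Hypothetical helper: only request AVX when the build machine has it."""
    args = ["-O3", "-ffast-math"]
    if "avx" in flags:
        args.append("-mavx")
    return args


def read_local_flags(path="/proc/cpuinfo"):
    """Read the current machine's flags (Linux only); empty set if unavailable."""
    try:
        with open(path) as f:
            return cpu_flags(f.read())
    except OSError:
        return set()
```

In setup.py this could be used as extra_compile_args=choose_compile_args(read_local_flags()). Using a set of whole flag names avoids accidentally matching "avx" inside "avx2" or "avx512f".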


Here are the results on my machine:

Before optimization:
  Lightning fast time: 0.10018469997703505.
  Numpy time: 0.024747799989199848.

After (with GCC):
  Lightning fast time: 0.02865879996534204.
  Numpy time: 0.02456870001878997.

After (with Clang):
  Lightning fast time: 0.01965239998753532.
  Numpy time: 0.024799799984975834.

The code produced by Clang is faster than NumPy on my machine.
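Dividing the reported timings gives the relative speed directly: roughly 4x slower before the changes, about 17% slower with GCC, and about 21% less time than NumPy with Clang. This small sketch only reuses the numbers printed above; the ratios assume both timings in each run were measured the same way, as in the test script.

```python
# Timings reported above: (Cython time, NumPy time), same units within a run
results = {
    "before optimization": (0.10018469997703505, 0.024747799989199848),
    "after (GCC)": (0.02865879996534204, 0.02456870001878997),
    "after (Clang)": (0.01965239998753532, 0.024799799984975834),
}

# A ratio above 1 means the Cython version is slower than NumPy
ratios = {
    label: cython_time / numpy_time
    for label, (cython_time, numpy_time) in results.items()
}

for label, ratio in ratios.items():
    print(f"{label}: Cython time / NumPy time = {ratio:.2f}")
```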

Under the hood

An analysis of the assembly code executed by the processor on my machine shows that the code only uses slow scalar instructions, contains unnecessary bounds checks, and is mainly limited by the result += ... operation: each addition has to wait for the previous one to finish (a loop-carried dependency).

162e3:┌─→movsd  xmm0,QWORD PTR [rbx+rax*8]  # Load 1 item
162e8:│  mulsd  xmm0,QWORD PTR [rsi+rax*8]  # Load 1 item and multiply
162ed:│  addsd  xmm1,xmm0                   # Main bottleneck (accumulation)
162f1:│  cmp    rdi,rax
162f4:│↓ je     163f8                       # Bound checking conditional jump
162fa:│  cmp    rdx,rax
162fd:│↓ je     16308                       # Bound checking conditional jump
162ff:│  add    rax,0x1
16303:├──cmp    rcx,rax
16306:└──jne    162e3

Once optimized (with GCC), the result is:

13720:┌─→vmovupd      ymm3,YMMWORD PTR [r13+rax*1+0x0]    # Load 4 items
13727:│  vmulpd       ymm0,ymm3,YMMWORD PTR [rcx+rax*1]   # Load 4 items and multiply
1372c:│  add          rax,0x20
13730:│  vaddpd       ymm1,ymm1,ymm0        # Still a bottleneck (but better)
13734:├──cmp          rdx,rax
13737:└──jne          13720

The result += ... operation is still the bottleneck in the optimized version, but this is much better since the loop works on 4 items at once. To remove the bottleneck, the loop must be partially unrolled using several independent accumulators, so that several additions can be in flight at the same time. However, GCC (which is the default compiler on my machine) is not able to do that properly, even when asked to with -funroll-loops, due to the loop-carried dependency. This is why OpenBLAS should be a bit faster than the code produced by GCC.
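Partial unrolling with independent accumulators boils down to splitting the single running sum into several ones. Here is a Python model of that transformation. The name dot_multi_acc is hypothetical, and Python itself does not get faster from this: the sketch only shows the order of operations a compiler would use. With num_acc=4 the partial sums correspond to the 4 lanes of GCC's single ymm accumulator; with num_acc=16 they roughly correspond to Clang's 4 ymm registers of 4 lanes each.

```python
def dot_multi_acc(a, b, num_acc=4):
    """Dot product using `num_acc` independent partial sums.

    Element i goes into partial[i % num_acc], so consecutive multiply-adds
    do not depend on each other and their latencies can overlap.
    """
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    n = len(a)
    partial = [0.0] * num_acc
    main = n - n % num_acc  # largest multiple of num_acc <= n

    # Unrolled main loop: num_acc independent accumulations per iteration
    for i in range(0, main, num_acc):
        for k in range(num_acc):
            partial[k] += a[i + k] * b[i + k]

    # Remainder loop for the last n % num_acc elements
    for i in range(main, n):
        partial[i % num_acc] += a[i] * b[i]

    # Final reduction of the partial sums, done once after the loop
    total = 0.0
    for p in partial:
        total += p
    return total


# Same data as the benchmark script, as plain Python lists
a = [float(i) for i in range(10000)]
b = [x / 2 for x in a]
print(dot_multi_acc(a, b, num_acc=1))   # a single chain, like the original loop
print(dot_multi_acc(a, b, num_acc=16))  # 16 partial sums, like Clang's 4x4 layout
```

Both calls print 166641667500.0: with these inputs every intermediate sum is exactly representable, so the reordering cannot change the result. In Cython, the same idea would be written with several cdef double accumulators, but whether the compiler then vectorizes the loop still depends on the compiler and its flags.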

Fortunately, Clang is able to do that by default. Here is the code produced by Clang:

59e0:┌─→vmovupd      ymm4,YMMWORD PTR [rax+rdi*8]       # load 4 items
59e5:│  vmovupd      ymm5,YMMWORD PTR [rax+rdi*8+0x20]  # load 4 items
59eb:│  vmovupd      ymm6,YMMWORD PTR [rax+rdi*8+0x40]  # load 4 items
59f1:│  vmovupd      ymm7,YMMWORD PTR [rax+rdi*8+0x60]  # load 4 items
59f7:│  vmulpd       ymm4,ymm4,YMMWORD PTR [rbx+rdi*8]
59fc:│  vaddpd       ymm0,ymm4,ymm0
5a00:│  vmulpd       ymm4,ymm5,YMMWORD PTR [rbx+rdi*8+0x20]
5a06:│  vaddpd       ymm1,ymm4,ymm1
5a0a:│  vmulpd       ymm4,ymm6,YMMWORD PTR [rbx+rdi*8+0x40]
5a10:│  vmulpd       ymm5,ymm7,YMMWORD PTR [rbx+rdi*8+0x60]
5a16:│  vaddpd       ymm2,ymm4,ymm2
5a1a:│  vaddpd       ymm3,ymm5,ymm3
5a1e:│  add          rdi,0x10
5a22:├──cmp          rsi,rdi
5a25:└──jne          59e0

The code is not optimal (it should unroll the loop at least 6 times because of the latency of the vaddpd instruction), but it is very good.
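The "unroll at least N times" reasoning is Little's law: an accumulator can only receive a new vaddpd once the previous one has finished, so keeping the adders busy requires about latency × throughput independent accumulators. The sketch below only does that arithmetic. The figures in the example (4-cycle latency, 2 vector additions per cycle) are assumptions for illustration, often quoted for Skylake-derived cores, and should be checked against instruction tables for the actual CPU; other limits, such as load bandwidth, are ignored.

```python
import math


def min_accumulators(latency_cycles, adds_per_cycle):
    """Independent accumulators needed to keep the FP adders fully busy.

    Each accumulator forms its own dependency chain, so latency x throughput
    additions must be in flight at once (Little's law). Other bottlenecks,
    such as the number of loads per cycle, are not modelled.
    """
    return math.ceil(latency_cycles * adds_per_cycle)


# Assumed, illustrative figures: 4-cycle latency, 2 vector adds per cycle
print(min_accumulators(4, 2))
```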

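Disclaimer: the technical posts on this site follow the CC BY-SA 4.0 license. If you need to repost, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.

Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM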