簡體   English   中英

加快點積乘法的 Cython 實現

[英]Speed up Cython implementation of dot product multiplication

我試圖通過在點積操作np.dot(a,b)上超越 Numpy 來學習 cython。 但是我的實現慢了大約 4 倍。

所以,這是我的 hello.pyx 文件 cython 實現:

cimport numpy as cnp
cnp.import_array()

cpdef double dot_product(double[::1] vect1, double[::1] vect2):
    cdef int size = vect1.shape[0]
    cdef double result = 0
    cdef int i = 0
    while i < size:
        result += vect1[i] * vect2[i]
        i += 1
    return result

這是 my.py 測試文件:

import timeit

setup = '''
import numpy as np
import hello

n = 10000
a = np.array([float(i) for i in range(n)])
b = np.array([i/2 for i in a])
'''
lf_code = 'res_lf = hello.dot_product(a, b)'
np_code = 'res_np = np.dot(a,b)'
n = 100
lf_time = timeit.timeit(lf_code, setup=setup, number=n) * 100
np_time = timeit.timeit(np_code, setup=setup, number=n) * 100

print(f'Lightning fast time: {lf_time}.')
print(f'Numpy time: {np_time}.')

控制台 output:

Lightning fast time: 0.12186000000156127.
Numpy time: 0.028800000001183435.

構建 hello.pyx 的命令:

python setup.py build_ext --inplace

setup.py 文件:

from distutils.core import Extension, setup
from Cython.Build import cythonize
import numpy as np

# define an extension that will be cythonized and compiled
ext = Extension(name="hello", sources=["hello.pyx"], include_dirs=[np.get_include()])
setup(ext_modules=cythonize(ext))

處理器:i7-7700T @ 2.90 GHz

與 Numpy(在大多數平台上默認使用 OpenBLAS)相比,問題主要來自缺乏 SIMD 指令(由於邊界檢查和低效的默認編譯器標志)。

要解決這個問題,您應該首先在hello.pix文件的開頭添加以下行:

#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False

然后,您應該使用這個新的setup.py文件:

from distutils.core import Extension, setup
from Cython.Build import cythonize
import numpy as np

# define an extension that will be cythonized and compiled
ext = Extension(name="hello", sources=["hello.pyx"], include_dirs=[np.get_include()], extra_compile_args=['-O3', '-mavx', '-ffast-math'])
setup(ext_modules=cythonize(ext))

請注意,標志取決於編譯器。 話雖如此,Clang 和 GCC 都支持它們(可能也支持 ICC)。 -O3告訴 Clang 和 GCC 使用更積極的優化,例如代碼的自動矢量化。 -mavx告訴他們使用 AVX 指令集(僅在相對較新的 x86-64 處理器上可用)。 -ffast-math告訴他們假設浮點數運算是關聯的(事實並非如此),並且您只使用有限/基本數(沒有 NaN,也沒有無窮大)。 如果不滿足上述假設,則程序可能會在運行時崩潰,因此請小心此類標志。

請注意,OpenBLAS 會根據您的機器和 AFAIK 自動選擇指令集,它不使用-ffast-math而是更安全(低級)的替代方案。


結果:

這是我機器上的結果:

Before optimization:
  Lightning fast time: 0.10018469997703505.
  Numpy time: 0.024747799989199848.

After (with GCC):
  Lightning fast time: 0.02865879996534204.
  Numpy time: 0.02456870001878997.

After (with Clang):
  Lightning fast time: 0.01965239998753532.
  Numpy time: 0.024799799984975834.

Clang 生成的代碼在我的機器上比 Numpy 快


在引擎蓋下

對我機器上處理器執行的匯編代碼的分析表明,該代碼僅使用慢標量指令,包含不必要的邊界檢查,並且主要受result +=...操作的限制(因為循環攜帶依賴)。

162e3:┌─→movsd  xmm0,QWORD PTR [rbx+rax*8]  # Load 1 item
162e8:│  mulsd  xmm0,QWORD PTR [rsi+rax*8]  # Load 1 item
162ed:│  addsd  xmm1,xmm0                   # Main bottleneck (accumulation)
162f1:│  cmp    rdi,rax
162f4:│↓ je     163f8                       # Bound checking conditional jump
162fa:│  cmp    rdx,rax
162fd:│↓ je     16308                       # Bound checking conditional jump
162ff:│  add    rax,0x1
16303:├──cmp    rcx,rax
16306:└──jne    162e3

優化后的結果是:

13720:┌─→vmovupd      ymm3,YMMWORD PTR [r13+rax*1+0x0]    # Load 4 items
13727:│  vmulpd       ymm0,ymm3,YMMWORD PTR [rcx+rax*1]   # Load 4 items
1372c:│  add          rax,0x20
13730:│  vaddpd       ymm1,ymm1,ymm0        # Still a bottleneck (but better)
13734:├──cmp          rdx,rax
13737:└──jne          13720

result +=...操作仍然是優化版本中的瓶頸,但這要好得多,因為循環一次可以處理 4 個項目。 要消除瓶頸,必須部分展開循環。 但是,GCC(這是我機器上的默認編譯器)無法正確執行此操作(即使在要求使用-funrol-loops時(由於循環攜帶依賴)。這就是為什么 OpenBLAS 應該比GCC 生成的代碼。

希望 Clang 默認能夠做到這一點。 下面是 Clang 生成的代碼:

59e0:┌─→vmovupd      ymm4,YMMWORD PTR [rax+rdi*8]       # load 4 items
59e5:│  vmovupd      ymm5,YMMWORD PTR [rax+rdi*8+0x20]  # load 4 items
59eb:│  vmovupd      ymm6,YMMWORD PTR [rax+rdi*8+0x40]  # load 4 items
59f1:│  vmovupd      ymm7,YMMWORD PTR [rax+rdi*8+0x60]  # load 4 items
59f7:│  vmulpd       ymm4,ymm4,YMMWORD PTR [rbx+rdi*8]
59fc:│  vaddpd       ymm0,ymm4,ymm0
5a00:│  vmulpd       ymm4,ymm5,YMMWORD PTR [rbx+rdi*8+0x20]
5a06:│  vaddpd       ymm1,ymm4,ymm1
5a0a:│  vmulpd       ymm4,ymm6,YMMWORD PTR [rbx+rdi*8+0x40]
5a10:│  vmulpd       ymm5,ymm7,YMMWORD PTR [rbx+rdi*8+0x60]
5a16:│  vaddpd       ymm2,ymm4,ymm2
5a1a:│  vaddpd       ymm3,ymm5,ymm3
5a1e:│  add          rdi,0x10
5a22:├──cmp          rsi,rdi
5a25:└──jne          59e0

該代碼不是最優的(因為由於vaddpd指令的延遲,它應該至少展開循環 6 次),但它非常好。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM