[英]Why dgemm (Cython compiled) is slower than numpy.dot
长话短说,我在Cython
构建了一个简单的乘法函数,调用scipy.linalg.cython_blas.dgemm
,编译它并针对基准Numpy.dot
运行它。 我听说过关于性能提升 50% 到 100 倍的神话,当我使用静态定义、数组维度预分配、内存视图、关闭检查等技巧时,我会看到。但后来我编写了自己的my_dot
函数(编译后),它比默认的Numpy.dot
慢 4 倍。 我真的不知道是什么原因,所以我只能猜测:
1) BLAS
库未链接
2)可能有一些我没有发现的内存开销
3) dot
使用了一些隐藏的魔法
4) setup.py
写得不好, c
代码没有优化编译
5) 我的my_dot
函数没有高效编写
下面是我的代码片段和我能想到的所有相关信息,这些信息可能有助于解决这个难题。 我很感激是否有人可以提供一些关于我做错了什么的见解,或者如何将性能提高到至少与默认的Numpy.dot
文件 1: model_cython/multi.pyx
。 您还需要文件夹中的model_cython/init.py
。
#cython: language_level=3
#cython: boundscheck=False
#cython: nonecheck=False
#cython: wraparound=False
#cython: infertypes=True
#cython: initializedcheck=False
#cython: cdivision=True
#distutils: extra_compile_args = -Wno-unused-function -Wno-unneeded-internal-declaration
from scipy.linalg.cython_blas cimport dgemm
import numpy as np
from numpy cimport ndarray, float64_t
from numpy cimport PyArray_ZEROS
cimport numpy as np
cimport cython
np.import_array()
ctypedef float64_t DOUBLE
def my_dot(double [::1, :] a, double [::1, :] b, int ashape0, int ashape1,
int bshape0, int bshape1):
cdef np.npy_intp cshape[2]
cshape[0] = <np.npy_intp> ashape0
cshape[1] = <np.npy_intp> bshape1
cdef:
int FORTRAN = 1
ndarray[DOUBLE, ndim=2] c = PyArray_ZEROS(2, cshape, np.NPY_DOUBLE, FORTRAN)
cdef double alpha = 1.0
cdef double beta = 0.0
dgemm("N", "N", &ashape0, &bshape1, &ashape1, &alpha, &a[0,0], &ashape0, &b[0,0], &bshape0, &beta, &c[0,0], &ashape0)
return c
文件 2: model_cython/example.py
。 执行基准测试的脚本
setup_str = """
import numpy as np
from numpy import float64
from multi import my_dot
a = np.ones((2,3), dtype=float64, order='F')
b = np.ones((3,2), dtype=float64, order='F')
print(a.flags)
ashape0, ashape1 = a.shape
bshape0, bshape1 = b.shape
"""
import timeit
print(timeit.timeit(stmt='c=my_dot(a,b, ashape0, ashape1, bshape0, bshape1)', setup=setup_str, number=100000))
print(timeit.timeit(stmt='c=a.dot(b)', setup=setup_str, number=100000))
文件 3: setup.py
。 编译.so
文件
from distutils.core import setup, Extension
from Cython.Build import cythonize
from Cython.Distutils import build_ext
import numpy
import os
basepath = os.path.dirname(os.path.realpath(__file__))
numpy_path = numpy.get_include()
package_name = 'multi'
setup(
name='multi',
cmdclass={'build_ext': build_ext},
ext_modules=[Extension(package_name,
[os.path.join(basepath, 'model_cython', 'multi.pyx')],
include_dirs=[numpy_path],
)],
)
文件 4: run.sh
。 执行setup.py
并移动内容的 Shell 脚本
python3 setup.py build_ext --inplace
path=$(pwd)
rm -r build
mv $path/multi.cpython-37m-darwin.so $path/model_cython/
rm $path/model_cython/multi.c
下面是编译消息的截图:
关于BLAS
,我的Numpy
在/usr/local/lib
正确链接到它,并且clang -bundle
似乎也在编译中添加-L/usr/local/lib
。 但也许这还不够?
Cython 擅长优化循环(这在 Python 中通常很慢),也是调用 C 的便捷方式(这正是您想要做的)。 但是,从 Python 调用 Cython 函数可能相对较慢 - 特别是因为您指定的所有类型都需要检查一致性。 因此,您通常会尝试在一个 Cython 调用之后隐藏大量工作,因此开销很小。
您几乎选择了最坏的情况:大量调用背后的一小部分工作。 Cython 或np.dot
是否会有更多的开销是相当随意的,但无论哪种方式,你正在测量的都是这个,而不是np.dot
与 BLAS dgemm
。
从您的评论看来,您实际上想要对两个 3D 数组的前两个维度进行点积。 因此,一个更有用的测试是尝试重现它。 以下是三个版本:
def einsum_mult(a,b):
# use np.einsum, won't benefit from Cython
return np.einsum("ijh,jkh->ikh",a,b)
def manual_mult(a,b):
# multiply one matrix at a time with numpy dot
# (could probably be optimized a bit with Cython)
c = np.empty((a.shape[0],b.shape[1],a.shape[2]),
dtype=np.float64, order='F')
for n in range(a.shape[2]):
c[:,:,n] = a[:,:,n].dot(b[:,:,n])
return c
def blas_version(double[::1,:,:] a,double[::1,:,:] b):
# uses dgemm
cdef double[::1,:,:] c = np.empty((a.shape[0], b.shape[1], a.shape[2]),
dtype=np.float64, order='F')
cdef double[::1,:] c_part
cdef int n
cdef double alpha = 1.0
cdef double beta = 0.0
cdef int ashape0 = a.shape[0], ashape1 = a.shape[1], bshape0 = b.shape[0], bshape1 = b.shape[1]
assert a.shape[2]==b.shape[2]
assert a.shape[1]==b.shape[0]
for n in range(a.shape[2]):
c_part = c[:,:,n]
dgemm("N", "N", &ashape0, &bshape1, &ashape1, &alpha, &a[0,0,n], &ashape0,
&b[0,0,n], &bshape0, &beta, &c_part[0,0], &ashape0)
return c
对于大小为(2,3,10000)
和(3,2,10000)
数组,重复 100 次,我得到:
manual_mult 1.6531286190001993 s (i.e. quite bad)
einsum 0.3215398370011826 s (pretty good)
blas_version 0.15762194800481666 s (best, pretty close to the "myth" performance gain you mention)
如果您充分利用 Cython 并将循环保留在编译代码中,则 BLAS 版本会很快。 (我没有花费任何精力来优化它,因此如果您尝试,您可能会击败它,但这只是为了说明这一点)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.