繁体   English   中英

在 Python(和 Cython)中计算两个矩阵的点积的最快方法是什么

[英]What is the fastest way to compute the dot product of two matrices in Python (and Cython)

我正在尝试用 Cython 制作一个 Python 库,我需要在其中实现点积。 我有一个非常简单的计算点积的方法,但是,对于较大的矩阵,它的运行速度不够快。

我花了很多时间在谷歌上搜索这个问题,并试图让它尽可能快地运行,但是我无法让它运行得更快。

下面的代码显示了我目前如何计算它的 Python 实现:

a = [[1, 2, 3], [4, 5, 6]]
b = [[1], [2], [3]]

def dot(a, b):
    c = [[0 for j in range(len(b[i]))] for i in range(len(a))]

    for i in range(len(c)):
        for j in range(len(c[i])):
            t = 0
            for k in range(len(b)):
                t += a[i][k] * b[k][j]
            c[i][j] = t
    return c

print(dot(a, b))
# [[14], [32]]

这确实给出了正确的计算结果( python [[14], [32]] ),但是,计算我将要使用它的内容需要很长时间。 任何有关我如何加快速度的帮助将不胜感激。 谢谢

您可以为此使用numpy Numpy 实现了 BLAS 规范(基本线性代数子程序),它们是线性代数库的低级例程(如矩阵乘法)的事实上的标准。 要获得两个矩阵的点积,例如AB您可以使用以下代码:

A = [[1, 2, 3], [4, 5, 6]]
B = [[1], [2], [3]]

import numpy as np #Import numpy

numpy_a = np.array(A) #Cast your nested lists to numpy arrays
numpy_b = np.array(B)
print(np.dot(numpy_a, numpy_b)) #Print the result

根据结构的索引成本,您可能会通过分解一些操作来提高速度:

def dot(a, b):
    c = [[0 for j in range(len(b[i]))] for i in range(len(a))]
    bt = transpose(b)        # can this be done once cheaply?
    for i in range(len(c)):
        a1 = a[i]
        c1 = c[i]
        for j in range(len(c1)):
            b1 = bt[j]
            t = 0
            for k in range(len(b)):
                t += a1[k] * b1[k]
            c1[j] = t
    return c

可以用惯用的 Python 编写内部k循环:

for a2, b2 in zip(a1, b1):
     t += a2 * b2

我不知道这在 cython 翻译中是否更快。

快速 cython 还需要将各种变量定义为intfloat等,因此它可以进行直接的c转换,而不是通过通用但昂贵的 Python 对象。 我不会尝试重复 cython 文档。

您应该注释(静态类型)所有可能的变量。 以下是我的解决方案,如果您愿意:

# mydot.pyx
import numpy as np
cimport cython

def dot_1(a, b):
    c = [[0 for j in range(len(b[i]))] for i in range(len(a))]

    for i in range(len(c)):
        for j in range(len(c[i])):
            t = 0
            for k in range(len(b)):
                t += a[i][k] * b[k][j]
            c[i][j] = t
    return c


@cython.boundscheck(False)  # turn off bounds-checking
@cython.wraparound(False)  # turn off negative index wrapping
def dot_2(double[:, :] A, double[:, :] B):
    cdef Py_ssize_t M = A.shape[0]
    cdef Py_ssize_t Na = A.shape[1]
    cdef Py_ssize_t Nb = B.shape[0]
    cdef Py_ssize_t K = B.shape[1]

    assert Na == Nb

    result = np.empty((M, K), dtype='d')
    cdef double[:, :] C = result

    cdef double t

    for m in range(M):
        for k in range(K):
            t = 0
            for n in range(Na):
                t += A[m, n] * B[n, k]
            C[m, k] = t

    return result

# app.py
import pyximport
from numpy import array
from scipy import median
from timeit import repeat

pyximport.install()
from mydot import dot_1, dot_2


a = array([[1, 2, 3], [4, 5, 6]], dtype='d')
b = array([[1], [2], [3]], dtype='d')

dot_1_t = repeat('dot_1(a, b)', repeat=1000, number=1, globals=globals())
dot_2_t = repeat('dot_2(a, b)', repeat=1000, number=1, globals=globals())

print(f'dot_1 took {median(dot_1_t)*1000} ms.')
print(f'dot_2 took {median(dot_2_t)*1000} ms.')

当您运行cython --annotate mydot.pyx ,Cython 将生成一个 HTML 文件来注释 Cython 代码。 在那里,黄色高光越暗,生成的 C 代码的 (Python) 开销越多。 您可以将两个解决方案(尤其是for循环)相互比较。

运行python app.py也应该给你更快的结果。 当然,如果您提供低于某个阈值的较小尺寸的输入,您将不会看到两者之间的有意义的速度差异,因为您没有足够的迭代。 然而,在某个阈值之后,速度差异应该是显着的,因为循环中的每次迭代对于您的版本来说都是昂贵的(参见较深的黄线)。

最后一点是,正如这个问题下的每个人都已经建议的那样,当您提供具有更大维度的矩阵时, numpy的函数应该具有更高的性能——它们使用来自底层 BLAS 和 LAPACK 实现的阻塞(子)矩阵操作而不是天真地一个一个地迭代索引。

PS:如果您想不仅在double s 上而且在其他有意义的算术类型(例如int s 和float s)上专门化dot_2 ,您应该检查 Cython 的fused types

编辑。 因为我的回答后来被选为答案,所以我想举一个更大尺寸输入的例子。 如果不是上面的app.py ,而是使用以下内容:

# app.py
import pyximport
from numpy import array, random as rnd
from scipy import median
from timeit import repeat

pyximport.install()
from mydot import dot_1, dot_2


M = 100
N = 100
K = 1

a = rnd.randn(M, N)
b = rnd.randn(N, K)

dot_1_t = repeat('dot_1(a, b)', repeat=1000, number=1, globals=globals())
dot_2_t = repeat('dot_2(a, b)', repeat=1000, number=1, globals=globals())

print(f'dot_1 took {median(dot_1_t)*1000} ms.')
print(f'dot_2 took {median(dot_2_t)*1000} ms.')

时间应该类似于以下内容:

dot_1 took 5.218300502747297 ms.
dot_2 took 0.013017997844144702 ms.

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM