Fast multiplication of a series of matrices

What is the fastest way to run:

    reduce(lambda x,y : x@y, ls)

in Python?

Here ls is a list of matrices. I don't have an Nvidia GPU, but I do have a lot of CPU cores to work with. I thought I could make the process run in parallel (splitting the list into pairwise products over log iterations), but it seems that for small (1000x1000) matrices this is actually the worst option. Here is the code I tried:

from multiprocessing import Pool
import numpy as np
from itertools import zip_longest

def matmul(x):
    # An unpaired trailing element is passed through unchanged
    if x[1] is None:
        return x[0]
    return x[0] @ x[1]

def fast_mul(ls):
    while True:
        n = len(ls)
        if n == 0:
            raise Exception("Splitting Error")
        if n == 1:
            return ls[0]
        if n == 2:
            return ls[0] @ ls[1]

        # Multiply adjacent pairs in parallel, halving the list on each pass
        with Pool(processes=(n // 2 + 1)) as pool:
            ls = pool.map(matmul, list(zip_longest(*[iter(ls)] * 2)))
    

EDIT: Throwing in another possible function.

EDIT: I added results for np.linalg.multi_dot, expecting it to be faster than the others, but it was actually somehow much slower. I suppose it was designed with other kinds of use cases in mind.


I'm not sure you can get much faster than that. Here are a few different implementations of the reduction for the case where the data is a 3D array of square matrices:

from multiprocessing import Pool
from functools import reduce
import numpy as np
import numba as nb

def matmul_n_naive(data):
    return reduce(np.matmul, data)

# If you don't care about modifying data pass copy=False
def matmul_n_binary(data, copy=True):
    if len(data) < 1:
        raise ValueError
    data = np.array(data, copy=copy)
    n, r, c = data.shape
    if n == 1:  # the reduction loop below never runs for a single matrix
        return np.array(data[0])
    s = 1
    # Tree reduction: each pass multiplies adjacent pairs in place via strided views
    while (n + s - 1) // s > 1:
        a = data[:n - s:2 * s]
        b = data[s:n:2 * s]
        np.matmul(a, b, out=a)
        s *= 2
    return np.array(a[0])

def matmul_n_pool(data):
    if len(data) < 1:
        raise ValueError
    lst = data
    with Pool() as pool:
        while len(lst) > 1:
            lst_next = pool.starmap(np.matmul, zip(lst[::2], lst[1::2]))
            if len(lst) % 2 != 0:
                lst_next.append(lst[-1])
            lst = lst_next
    return lst[0]

@nb.njit(parallel=False)
def matmul_n_numba_nopar(data):
    res = np.eye(data.shape[1], data.shape[2], dtype=data.dtype)
    for i in nb.prange(len(data)):
        res = res @ data[i]
    return res

@nb.njit(parallel=True)
def matmul_n_numba_par(data):
    res = np.eye(data.shape[1], data.shape[2], dtype=data.dtype)
    for i in nb.prange(len(data)):  # Numba knows how to do parallel reductions correctly
        res = res @ data[i]
    return res

def matmul_n_multidot(data):
    return np.linalg.multi_dot(data)

And a test:

# Test
import numpy as np

np.random.seed(0)
a = np.random.rand(10, 100, 100) * 2 - 1
b1 = matmul_n_naive(a)
b2 = matmul_n_binary(a)
b3 = matmul_n_pool(a)
b4 = matmul_n_numba_nopar(a)
b5 = matmul_n_numba_par(a)
b6 = matmul_n_multidot(a)
print(np.allclose(b1, b2))
# True
print(np.allclose(b1, b3))
# True
print(np.allclose(b1, b4))
# True
print(np.allclose(b1, b5))
# True
print(np.allclose(b1, b6))
# True

Here are some benchmarks. There doesn't seem to be a consistent winner, but the "naive" solution is quite good across the board, binary and Numba trade places depending on the sizes, the process pool is never competitive, and np.linalg.multi_dot does not seem to be very advantageous for square matrices.

import numpy as np

# 10 matrices 1000x1000
np.random.seed(0)
a = np.random.rand(10, 1000, 1000) * 0.1 - 0.05
%timeit matmul_n_naive(a)
# 121 ms ± 6.09 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit matmul_n_binary(a)
# 165 ms ± 3.68 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit matmul_n_numba_nopar(a)
# 108 ms ± 510 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit matmul_n_numba_par(a)
# 244 ms ± 7.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit matmul_n_multidot(a)
# 132 ms ± 2.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# 200 matrices 100x100
np.random.seed(0)
a = np.random.rand(200, 100, 100) * 0.1 - 0.05
%timeit matmul_n_naive(a)
# 4.4 ms ± 226 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit matmul_n_binary(a)
# 13.4 ms ± 299 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit matmul_n_numba_nopar(a)
# 9.51 ms ± 126 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit matmul_n_numba_par(a)
# 4.93 ms ± 146 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit matmul_n_multidot(a)
# 1.14 s ± 22.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# 300 matrices 10x10
np.random.seed(0)
a = np.random.rand(300, 10, 10) * 0.1 - 0.05
%timeit matmul_n_naive(a)
# 526 µs ± 953 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_binary(a)
# 152 µs ± 508 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit matmul_n_pool(a)
# 610 ms ± 5.93 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit matmul_n_numba_nopar(a)
# 239 µs ± 1.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_numba_par(a)
# 175 µs ± 422 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit matmul_n_multidot(a)
# 3.68 s ± 87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# 1000 matrices 10x10
np.random.seed(0)
a = np.random.rand(1000, 10, 10) * 0.1 - 0.05
%timeit matmul_n_naive(a)
# 1.56 ms ± 4.49 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_binary(a)
# 392 µs ± 790 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_pool(a)
# 727 ms ± 12.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit matmul_n_numba_nopar(a)
# 589 µs ± 356 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_numba_par(a)
# 451 µs ± 1.68 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_multidot(a)
# Never finished...

There is a function that does exactly this: np.linalg.multi_dot, which is supposedly optimized for the best evaluation order:

np.linalg.multi_dot(ls)

In fact, the docs say something very close to your original phrasing:

Think of multi_dot as:

def multi_dot(arrays):
    return functools.reduce(np.dot, arrays)

You can also try np.einsum, which lets you multiply up to 25 matrices:

from string import ascii_lowercase

ls = [...]
index = ','.join(ascii_lowercase[x:x + 2] for x in range(len(ls)))
index += f'->{index[0]}{index[-1]}'
np.einsum(index, *ls)
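
As a quick sanity check (a minimal sketch; the 5-matrix list below is just a hypothetical example), the subscripts this generates chain each matrix's columns to the next one's rows, and the result agrees with a plain left-to-right reduction:

from functools import reduce
from string import ascii_lowercase
import numpy as np

ls = [np.random.rand(8, 8) for _ in range(5)]  # hypothetical example input

# Same construction as above: consecutive letter pairs share an index
index = ','.join(ascii_lowercase[x:x + 2] for x in range(len(ls)))
index += f'->{index[0]}{index[-1]}'
print(index)                                   # ab,bc,cd,de,ef->af
print(np.allclose(np.einsum(index, *ls), reduce(np.matmul, ls)))  # True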

Timings

The simple case:

ls = np.random.rand(100, 1000, 1000) - 0.5

%timeit reduce(lambda x, y : x @ y, ls)
4.3 s ± 76.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
4.35 s ± 84.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
4.86 s ± 68.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
5.24 s ± 66.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

A more complex case:

ls = [x.T if i % 2 else x for i, x in enumerate(np.random.rand(100, 2000, 500) - 0.5)]

%timeit reduce(lambda x, y : x @ y, ls)
7.94 s ± 96.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
7.91 s ± 33.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
9.38 s ± 111 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
2.03 s ± 52.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Note that the up-front work done by multi_dot is a net loss in the simple case (and, weirder still, the lambda runs slightly faster than raw np.matmul), but it saves 75% of the time in the less straightforward case.
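
That saving is the classic matrix-chain-ordering effect. As a back-of-the-envelope sketch using the shapes from the timing above (2000x500 alternating with 500x2000): multiplying a p×q matrix by a q×r one costs roughly p·q·r scalar multiplications, so the grouping matters enormously:

# Rough FLOP count: (p x q) @ (q x r) costs about p*q*r multiplications.
# A is 2000x500, B is 500x2000, C is 2000x500, as in the timing above.
p, q = 2000, 500
left_first  = p * q * p + p * p * q   # (A @ B) @ C -> huge 2000x2000 intermediate
right_first = q * p * q + p * q * q   # A @ (B @ C) -> only a 500x500 intermediate
print(left_first / right_first)       # 4.0, i.e. the ~75% saving multi_dot finds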

So, for completeness, here is a non-square case:

ls = [x.T if i % 2 else x for i, x in enumerate(np.random.rand(100, 400, 300) - 0.5)]

%timeit reduce(lambda x, y : x @ y, ls)
245 ms ± 8.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
245 ms ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
284 ms ± 12.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
638 ms ± 12.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

So for most general cases, your original reduce call actually seems to be about as good as you need. My only suggestion would be to use operator.matmul instead of a lambda.
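
For reference, that replacement is a drop-in change. The speed difference is marginal for large matrices, since the per-call Python overhead is tiny next to the multiplications themselves, but it reads more cleanly:

from functools import reduce
from operator import matmul

result = reduce(matmul, ls)  # equivalent to reduce(lambda x, y: x @ y, ls)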
