[英]Fast multiplication of a series of matrices
什么是最快的运行方式:
reduce(lambda x,y : x@y, ls)
在蟒蛇?
对于矩阵列表ls
。 我没有 Nvidia GPU,但我确实有很多 CPU 内核可以使用。 我以为我可以使该过程并行工作(将其拆分为对log
迭代),但似乎对于小( 1000x1000
)矩阵,这实际上是最糟糕的。 这是我试过的代码:
from multiprocessing import Pool
import numpy as np
from itertools import zip_longest
def matmul(x):
if x[1] is None:
return x[0]
return x[1]@x[0]
def fast_mul(ls):
while True:
n = len(ls)
if n == 0:
raise Exception("Splitting Error")
if n == 1:
return ls[0]
if n == 2:
return ls[1]@ls[0]
with Pool(processes=(n//2+1)) as pool:
ls = pool.map(matmul, list(zip_longest(*[iter(ls)]*2)))
编辑:抛出另一个可能的功能
编辑:我用np.linalg.multi_dot
添加了结果,预计它会比其他的更快,但实际上它以某种方式慢得多。 我想它是考虑到其他类型的用例的设计。
我不确定你能比这快得多。 以下是数据为 3D 方阵数组的情况下的几种不同的归约实现:
from multiprocessing import Pool
from functools import reduce
import numpy as np
import numba as nb
def matmul_n_naive(data):
return reduce(np.matmul, data)
# If you don't care about modifying data pass copy=False
def matmul_n_binary(data, copy=True):
if len(data) < 1:
raise ValueError
data = np.array(data, copy=copy)
n, r, c = data.shape
dt = data.dtype
s = 1
while (n + s - 1) // s > 1:
a = data[:n - s:2 * s]
b = data[s:n:2 * s]
np.matmul(a, b, out=a)
s *= 2
return np.array(a[0])
def matmul_n_pool(data):
if len(data) < 1:
raise ValueError
lst = data
with Pool() as pool:
while len(lst) > 1:
lst_next = pool.starmap(np.matmul, zip(lst[::2], lst[1::2]))
if len(lst) % 2 != 0:
lst_next.append(lst[-1])
lst = lst_next
return lst[0]
@nb.njit(parallel=False)
def matmul_n_numba_nopar(data):
res = np.eye(data.shape[1], data.shape[2], dtype=data.dtype)
for i in nb.prange(len(data)):
res = res @ data[i]
return res
@nb.njit(parallel=True)
def matmul_n_numba_par(data):
res = np.eye(data.shape[1], data.shape[2], dtype=data.dtype)
for i in nb.prange(len(data)): # Numba knows how to do parallel reductions correctly
res = res @ data[i]
return res
def matmul_n_multidot(data):
return np.linalg.multi_dot(data)
还有一个测试:
# Test
import numpy as np
np.random.seed(0)
a = np.random.rand(10, 100, 100) * 2 - 1
b1 = matmul_n_naive(a)
b2 = matmul_n_binary(a)
b3 = matmul_n_pool(a)
b4 = matmul_n_numba_nopar(a)
b5 = matmul_n_numba_par(a)
b6 = matmul_n_multidot(a)
print(np.allclose(b1, b2))
# True
print(np.allclose(b1, b3))
# True
print(np.allclose(b1, b4))
# True
print(np.allclose(b1, b5))
# True
print(np.allclose(b1, b6))
# True
这里有一些基准测试,似乎没有一致的赢家,但“天真的”解决方案在各方面都非常好,二进制和 Numba 各不相同,进程池不是很好,而且np.linalg.multi_dot
似乎不是很有利方阵。
import numpy as np
# 10 matrices 1000x1000
np.random.seed(0)
a = np.random.rand(10, 1000, 1000) * 0.1 - 0.05
%timeit matmul_n_naive(a)
# 121 ms ± 6.09 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit matmul_n_binary(a)
# 165 ms ± 3.68 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit matmul_n_numba_nopar(a)
# 108 ms ± 510 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit matmul_n_numba_par(a)
# 244 ms ± 7.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit matmul_n_multidot(a)
# 132 ms ± 2.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 200 matrices 100x100
np.random.seed(0)
a = np.random.rand(200, 100, 100) * 0.1 - 0.05
%timeit matmul_n_naive(a)
# 4.4 ms ± 226 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit matmul_n_binary(a)
# 13.4 ms ± 299 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit matmul_n_numba_nopar(a)
# 9.51 ms ± 126 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit matmul_n_numba_par(a)
# 4.93 ms ± 146 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit matmul_n_multidot(a)
# 1.14 s ± 22.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 300 matrices 10x10
np.random.seed(0)
a = np.random.rand(300, 10, 10) * 0.1 - 0.05
%timeit matmul_n_naive(a)
# 526 µs ± 953 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_binary(a)
# 152 µs ± 508 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit matmul_n_pool(a)
# 610 ms ± 5.93 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit matmul_n_numba_nopar(a)
# 239 µs ± 1.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_numba_par(a)
# 175 µs ± 422 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit matmul_n_multidot(a)
# 3.68 s ± 87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 1000 matrices 10x10
np.random.seed(0)
a = np.random.rand(1000, 10, 10) * 0.1 - 0.05
%timeit matmul_n_naive(a)
# 1.56 ms ± 4.49 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_binary(a)
# 392 µs ± 790 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_pool(a)
# 727 ms ± 12.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit matmul_n_numba_nopar(a)
# 589 µs ± 356 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_numba_par(a)
# 451 µs ± 1.68 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_multidot(a)
# Never finished...
有一个函数可以做到这一点: np.linalg.multi_dot
,据说针对最佳评估顺序进行了优化:
np.linalg.multi_dot(ls)
事实上,文档说的非常接近你原来的措辞:
将
multi_dot
视为:def multi_dot(arrays): return functools.reduce(np.dot, arrays)
您也可以尝试np.einsum
,它可以让您乘以多达 25 个矩阵:
from string import ascii_lowercase
ls = [...]
index = ','.join(ascii_lowercase[x:x + 2] for x in range(len(ls)))
index += f'->{index[0]}{index[-1]}'
np.einsum(index, *ls)
定时
简单案例:
ls = np.random.rand(100, 1000, 1000) - 0.5
%timeit reduce(lambda x, y : x @ y, ls)
4.3 s ± 76.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
4.35 s ± 84.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
4.86 s ± 68.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
5.24 s ± 66.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
更复杂的情况:
ls = [x.T if i % 2 else x for i, x in enumerate(np.random.rand(100, 2000, 500) - 0.5)]
%timeit reduce(lambda x, y : x @ y, ls)
7.94 s ± 96.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
7.91 s ± 33.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
9.38 s ± 111 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
2.03 s ± 52.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
请注意, multi_dot
完成的前期工作在简单的情况下有负面好处(更multi_dot
是, lambda
工作速度比原始运算符快),但在不太直接的情况下节省了 75% 的时间。
所以为了完整起见,这里是一个非方形的情况:
ls = [x.T if i % 2 else x for i, x in enumerate(np.random.rand(100, 400, 300) - 0.5)]
%timeit reduce(lambda x, y : x @ y, ls)
245 ms ± 8.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
245 ms ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
284 ms ± 12.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
638 ms ± 12.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
因此,对于大多数一般情况,您的原始reduce
调用实际上似乎与您需要的一样好。 我唯一的建议是使用operator.matmul
而不是 lambda。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.