[英]Numpy batch dot product
假設我有兩個向量並希望取它們的點積; 這很簡單,
import numpy as np
a = np.random.rand(3)
b = np.random.rand(3)
result = np.dot(a,b)
如果我有一堆向量並且我希望每個向量都有點,那么最天真的代碼是
# 5 = number of vectors
a = np.random.rand(5,3)
b = np.random.rand(5,3)
result = [np.dot(aa,bb) for aa, bb in zip(a,b)]
批量計算的兩種方法是使用乘法和求和,以及 einsum,
result = np.sum(a*b, axis=1)
# or
result = np.einsum('ij,ij->i', a, b)
但是,這些都沒有發送到 BLAS 后端,因此只使用一個內核。 當N
很大時,比如 100 萬,這並不是很好。
tensordot確實調度到 BLAS 后端。 使用 tensordot 進行此計算的一種糟糕方法是
np.diag(np.tensordot(a,b, axes=[1,1])
這很糟糕,因為它分配了一個N*N
矩陣,並且大部分元素都是浪費工作。
另一種(非常快)方法是隱藏的 inner1d function
from numpy.core.umath_tests import inner1d
result = inner1d(a,b)
但這似乎不可行,因為可能公開導出它的問題已經過時了。 這仍然歸結為在 C 中編寫循環,而不是使用多個內核。
有沒有辦法讓dot
、 matmul
或tensordot
在多個內核上一次完成所有這些點積?
首先,沒有直接的BLAS function可以做到這一點。 使用許多 1 級 BLAS 函數效率不高,因為使用多線程進行非常短時間的計算往往會引入相當大的開銷,並且不使用多線程可能不是最佳選擇。 盡管如此,這種計算主要是受內存限制的,因此它在具有許多內核的平台上擴展性很差(很少幾個內核通常足以使 memory 帶寬飽和)。
一個簡單的解決方案是使用Numexpr package 應該非常有效地做到這一點(它應該避免創建臨時 arrays 並且還應該使用多個線程)。 但是,在這種情況下,大陣列的性能有些令人失望。
最好的解決方案似乎是使用Numba (或 Cython)。 Numba 可以為小型和大型輸入 arrays 生成快速代碼,並且很容易並行化代碼。 但是請注意,管理線程會引入一個開銷,這對於小型陣列來說可能相當大(在某些多核平台上最多幾毫秒)。
這是一個 Numexpr 實現:
import numexpr as ne
expr = ne.NumExpr('sum(a * b, axis=1)')
result = expr.run(a, b)
這是一個(順序)Numba 實現:
import numba as nb
# Use `parallel=True` for a parallel implementation
@nb.njit('float64[:](float64[:,::1], float64[:,::1])')
def multiDots(a, b):
assert a.shape == b.shape
n, m = a.shape
res = np.empty(n, dtype=np.float64)
# Use `nb.prange` instead of `range` to run the loop in parallel
for i in range(n):
s = 0.0
for j in range(m):
s += a[i,j] * b[i,j]
res[i] = s
return res
result = multiDots(a, b)
以下是(舊)2 核機器上的一些基准測試:
On small 5x3 arrays:
np.einsum('ij,ij->i', a, b, optimize=True): 45.2 us
Numba (parallel): 12.1 us
np.sum(a*b, axis=1): 9.5 us
np.einsum('ij,ij->i', a, b): 6.5 us
Numexpr: 3.2 us
Numba (sequential): 1.3 us
On small 1000000x3 arrays:
np.sum(a*b, axis=1): 27.8 ms
Numexpr: 15.3 ms
np.einsum('ij,ij->i', a, b, optimize=True): 9.0 ms
np.einsum('ij,ij->i', a, b): 8.8 ms
Numba (sequential): 6.8 ms
Numba (parallel): 5.3 ms
順序 Numba 實現提供了很好的權衡。 如果你真的想要最好的性能,你可以使用開關。 但是,以獨立於平台的方式選擇最佳n
閾值並不是那么容易。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.