
Numpy batch dot product

Suppose I have two vectors and wish to take their dot product; this is simple:

import numpy as np

a = np.random.rand(3)
b = np.random.rand(3)

result = np.dot(a,b)

If I have stacks of vectors and want to dot each pair of rows, the most naive code is:

# 5 = number of vectors
a = np.random.rand(5,3)
b = np.random.rand(5,3)
result = [np.dot(aa,bb) for aa, bb in zip(a,b)]

Two ways to batch this computation are an elementwise multiply followed by a sum, and einsum:

result = np.sum(a*b, axis=1)

# or
result = np.einsum('ij,ij->i', a, b)
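A quick sanity check that the naive loop and the two batched versions compute the same thing. This is a sketch added for illustration; the seed and array sizes are arbitrary choices:

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed, for reproducibility
a = rng.random((5, 3))
b = rng.random((5, 3))

# Naive loop: one np.dot call per pair of rows
naive = np.array([np.dot(aa, bb) for aa, bb in zip(a, b)])

# Batched versions: both reduce over axis 1 of the elementwise product
summed = np.sum(a * b, axis=1)
einsummed = np.einsum('ij,ij->i', a, b)

print(np.allclose(naive, summed), np.allclose(naive, einsummed))  # → True True
```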

However, neither of these dispatches to the BLAS backend, so both use only a single core. This is not great when N is very large, say 1 million.

tensordot does dispatch to the BLAS backend. A terrible way to do this computation with tensordot is:

np.diag(np.tensordot(a, b, axes=[1, 1]))

This is terrible because it allocates an N*N matrix, and most of its elements are wasted work: only the N diagonal entries are kept.
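To make the waste concrete, the sketch below (illustrative, with a deliberately small `n`) inspects the intermediate that `tensordot` builds. With N = 1 million rows the same intermediate would hold 10^12 float64 values, i.e. about 8 TB:

```python
import numpy as np

n = 1000  # kept small on purpose: the intermediate below is n x n
a = np.random.rand(n, 3)
b = np.random.rand(n, 3)

# Contracting axis 1 of both inputs computes every pairwise row product (a @ b.T)
full = np.tensordot(a, b, axes=[1, 1])
print(full.shape)         # → (1000, 1000)
print(full.nbytes / 1e6)  # → 8.0 (MB of float64, for only n useful entries)

# Only the diagonal holds the row-wise dot products we actually want
diag_result = np.diag(full)
```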

Another (brilliantly fast) approach is the hidden inner1d function:

from numpy.core.umath_tests import inner1d

result = inner1d(a,b)

but this doesn't seem viable, since the issue that might export it publicly has gone stale. And it still boils down to a loop written in C rather than using multiple cores.
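If the hidden import is unavailable, the same semantics (a dot product over the last axis, broadcast over all leading axes, i.e. the gufunc signature `(i),(i)->()`) can be reproduced with `einsum` and an ellipsis. `inner1d_like` is a hypothetical helper name: it is a sketch that matches inner1d's results, not its speed, and like the einsum above it runs on a single core:

```python
import numpy as np

def inner1d_like(a, b):
    """Hypothetical stand-in for inner1d: dot product over the last axis,
    broadcasting over all leading axes (signature (i),(i)->())."""
    a = np.asarray(a)
    b = np.asarray(b)
    # '...' lets einsum broadcast any number of leading dimensions
    return np.einsum('...i,...i->...', a, b)

a = np.random.rand(5, 3)
b = np.random.rand(5, 3)
result = inner1d_like(a, b)  # shape (5,)
```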

Is there a way to get dot, matmul, or tensordot to do all these dot products at once, on multiple cores?

First of all, there is no BLAS function that does this directly. Calling many level-1 BLAS functions is not very efficient: using multiple threads for such short computations tends to introduce a large overhead, while not using multiple threads may be sub-optimal. Moreover, this computation is mainly memory-bound, so it scales poorly on platforms with many cores (a few cores are often enough to saturate the memory bandwidth).
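The memory-bound claim can be checked with a back-of-the-envelope estimate. The sketch below assumes float64 inputs of shape 1000000x3 and that every element of `a` and `b` is read from main memory exactly once. Modern CPU cores can typically perform far more than one floating-point operation per byte fetched from DRAM, so a ratio around 0.1 flop/byte means the cores spend most of their time waiting on memory:

```python
# Rough estimate; assumes float64 data and one pass over each input array.
n, m = 1_000_000, 3
bytes_per_float = 8

flops = 2 * m * n                                # one multiply + one add per element
bytes_moved = (2 * n * m + n) * bytes_per_float  # read a and b, write the result

intensity = flops / bytes_moved                  # flops per byte of memory traffic
print(flops, bytes_moved, round(intensity, 3))   # → 6000000 56000000 0.107
```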

One simple solution is to use the Numexpr package, which should do this quite efficiently (it should avoid creating temporary arrays and should also use multiple threads). However, its performance is somewhat disappointing for big arrays in this case.

The best solution appears to be Numba (or Cython). Numba can generate fast code for both small and big input arrays, and it is easy to parallelize the code. Note, however, that managing threads introduces an overhead that can be quite large for small arrays (up to a few ms on some many-core platforms).

Here is a Numexpr implementation:

import numexpr as ne
expr = ne.NumExpr('sum(a * b, axis=1)')
result = expr.run(a, b)

Here is a (sequential) Numba implementation:

import numba as nb

# Use `parallel=True` for a parallel implementation
@nb.njit('float64[:](float64[:,::1], float64[:,::1])')
def multiDots(a, b):
    assert a.shape == b.shape
    n, m = a.shape
    res = np.empty(n, dtype=np.float64)

    # Use `nb.prange` instead of `range` to run the loop in parallel
    for i in range(n):
        s = 0.0
        for j in range(m):
            s += a[i,j] * b[i,j]
        res[i] = s

    return res

result = multiDots(a, b)

Here are some benchmarks on an old 2-core machine:

On small 5x3 arrays:
    np.einsum('ij,ij->i', a, b, optimize=True):  45.2 us
    Numba (parallel):                            12.1 us
    np.sum(a*b, axis=1):                          9.5 us
    np.einsum('ij,ij->i', a, b):                  6.5 us
    Numexpr:                                      3.2 us
    Numba (sequential):                           1.3 us

On large 1000000x3 arrays:
    np.sum(a*b, axis=1):                         27.8 ms
    Numexpr:                                     15.3 ms
    np.einsum('ij,ij->i', a, b, optimize=True):   9.0 ms
    np.einsum('ij,ij->i', a, b):                  8.8 ms
    Numba (sequential):                           6.8 ms
    Numba (parallel):                             5.3 ms
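To produce comparable numbers on other hardware, a minimal `timeit` harness could look like the sketch below. It only covers the NumPy variants; the Numexpr and Numba candidates can be added to the `candidates` dict the same way (with an explicit signature, Numba compiles at decoration time, so the first timed call does not include compilation). `bench` is a hypothetical helper, and no results are implied:

```python
import timeit
import numpy as np

def bench(n, m=3, repeat=5, number=5):
    """Return the best per-call time (in seconds) for each NumPy variant."""
    a = np.random.rand(n, m)
    b = np.random.rand(n, m)
    candidates = {
        "np.sum(a*b, axis=1)": lambda: np.sum(a * b, axis=1),
        "np.einsum": lambda: np.einsum('ij,ij->i', a, b),
        "np.einsum, optimize=True": lambda: np.einsum('ij,ij->i', a, b, optimize=True),
    }
    results = {}
    for name, fn in candidates.items():
        # Taking the minimum over repeats reduces noise from other processes
        times = timeit.repeat(fn, repeat=repeat, number=number)
        results[name] = min(times) / number
    return results

for n in (5, 1_000_000):
    for name, t in bench(n).items():
        print(f"{n:>9} x 3  {name:<26} {t * 1e6:10.1f} us")
```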

The sequential Numba implementation gives a good trade-off. If you really want the best performance, you can switch between the sequential and parallel versions depending on n. Choosing the best n threshold in a platform-independent way is not easy, though.
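Such a switch can be a thin wrapper around two implementations, for example the sequential `multiDots` and a parallel copy of it compiled with `parallel=True` and `nb.prange`. The sketch below is illustrative: `make_dispatcher` and `N_THRESHOLD` are hypothetical names, the threshold value is a placeholder to be tuned by benchmarking on the target machine, and NumPy stand-ins are used so it runs without Numba:

```python
import numpy as np

# Placeholder cut-off (hypothetical): the best value is platform-dependent.
N_THRESHOLD = 100_000

def make_dispatcher(small_impl, large_impl, threshold=N_THRESHOLD):
    """Build a function that calls `small_impl` for fewer than `threshold`
    rows and `large_impl` otherwise."""
    def multi_dots(a, b):
        if a.shape[0] < threshold:
            return small_impl(a, b)
        return large_impl(a, b)
    return multi_dots

# Stand-ins so the sketch runs anywhere; swap in the sequential and
# parallel Numba functions in practice.
dots = make_dispatcher(lambda a, b: np.sum(a * b, axis=1),
                       lambda a, b: np.einsum('ij,ij->i', a, b))
result = dots(np.random.rand(5, 3), np.random.rand(5, 3))
```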
