簡體   English   中英

Numba代碼比純python慢

[英]Numba code slower than pure python

我一直在努力加快粒子濾波器的重采樣計算。 由於python有很多方法可以加速它,我雖然會嘗試所有這些。 不幸的是,numba版本非常慢。 由於Numba應該加速,我認為這是我的錯誤。

我嘗試了4個不同的版本:

  1. Numba
  2. 蟒蛇
  3. NumPy的
  4. 用Cython

每個代碼如下:

import numpy as np
import scipy as sp
import numba as nb
from cython_resample import cython_resample

@nb.autojit
def numba_resample(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

def python_resample(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

def numpy_resample(qs, xs, rands):
    results = np.empty_like(qs)
    lookup = sp.cumsum(qs)
    for j, key in enumerate(rands):
        i = sp.argmax(lookup>key)
        results[j] = xs[i]
    return results

#The following is the code for the cython module. It was compiled in a
#separate file, but is included here to aid in the question.
"""
import numpy as np
cimport numpy as np
cimport cython

DTYPE = np.float64

ctypedef np.float64_t DTYPE_t

@cython.boundscheck(False)
def cython_resample(np.ndarray[DTYPE_t, ndim=1] qs, 
             np.ndarray[DTYPE_t, ndim=1] xs, 
             np.ndarray[DTYPE_t, ndim=1] rands):
    if qs.shape[0] != xs.shape[0] or qs.shape[0] != rands.shape[0]:
        raise ValueError("Arrays must have same shape")
    assert qs.dtype == xs.dtype == rands.dtype == DTYPE

    cdef unsigned int n = qs.shape[0]
    cdef unsigned int i, j 
    cdef np.ndarray[DTYPE_t, ndim=1] lookup = np.cumsum(qs)
    cdef np.ndarray[DTYPE_t, ndim=1] results = np.zeros(n, dtype=DTYPE)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results
"""

if __name__ == '__main__':
    n = 100
    xs = np.arange(n, dtype=np.float64)
    qs = np.array([1.0/n,]*n)
    rands = np.random.rand(n)

    print "Timing Numba Function:"
    %timeit numba_resample(qs, xs, rands)
    print "Timing Python Function:"
    %timeit python_resample(qs, xs, rands)
    print "Timing Numpy Function:"
    %timeit numpy_resample(qs, xs, rands)
    print "Timing Cython Function:"
    %timeit cython_resample(qs, xs, rands)

這導致以下輸出:

Timing Numba Function:
1 loops, best of 3: 8.23 ms per loop
Timing Python Function:
100 loops, best of 3: 2.48 ms per loop
Timing Numpy Function:
1000 loops, best of 3: 793 µs per loop
Timing Cython Function:
10000 loops, best of 3: 25 µs per loop

知道為什么numba代碼如此之慢? 我認為它至少可以與Numpy相媲美。

注意:如果有人對如何加速Numpy或Cython代碼示例有任何想法,那也不錯:)我的主要問題是關於Numba。

問題是numba無法直觀地lookup類型。 如果在方法中放置了print nb.typeof(lookup) ,你會發現numba將它視為一個對象,這很慢。 通常我會在locals dict中定義lookup類型,但是我遇到了一個奇怪的錯誤。 相反,我只是創建了一個小包裝器,以便我可以顯式定義輸入和輸出類型。

@nb.jit(nb.f8[:](nb.f8[:]))
def numba_cumsum(x):
    return np.cumsum(x)

@nb.autojit
def numba_resample2(qs, xs, rands):
    n = qs.shape[0]
    #lookup = np.cumsum(qs)
    lookup = numba_cumsum(qs)
    results = np.empty(n)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

然后我的時間是:

print "Timing Numba Function:"
%timeit numba_resample(qs, xs, rands)

print "Timing Revised Numba Function:"
%timeit numba_resample2(qs, xs, rands)

Timing Numba Function:
100 loops, best of 3: 8.1 ms per loop
Timing Revised Numba Function:
100000 loops, best of 3: 15.3 µs per loop

如果你使用jit而不是autojit你甚至可以更快一點:

@nb.jit(nb.f8[:](nb.f8[:], nb.f8[:], nb.f8[:]))

對我來說,它從15.3微秒降低到12.5微秒,但它仍然令人印象深刻的autojit做得如何。

更快的numpy版本(與numpy_resample相比加速10倍)

def numpy_faster(qs, xs, rands):
    lookup = np.cumsum(qs)
    mm = lookup[None,:]>rands[:,None]
    I = np.argmax(mm,1)
    return xs[I]

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM