How can I use Numba to efficiently speed up a simple Moving Average calculation

I am trying to use Numba to speed up some simple iterative functions for stock market analysis. I'm not interested in Pandas or NumPy for this; I'm just trying to understand what the approach would be for a nopython (@njit) function.

Here is a simple moving average function:

def sma_plain(src,p):
    win = []
    res = []
    for n in src:
        win.append(n)
        if len(win) > p:
            win = win[1:]
        res.append(sum(win)/len(win))
    return res

I can tell immediately that I can't create the empty win and res lists with Numba. I tried to use Numba's List(), but that can't be initialized as empty within the function. I tried starting by copying src for the output and using a slice of src to instantiate the window (by window I mean the set of values that will be summed), but my code won't compile. Also, in one of my attempts that did compile, timeit produced slower results than my original function, most likely because sum() can't be used. The only example I could see in the docs used GPU vectorization, with a syntax I didn't understand. I'm not interested in that yet; I simply want to understand the process with Numba for nopython.

Also, I wrote the function so that it uses a smaller window for the initial values. Ideally I would like to write None or null values to the list when the window hasn't matured, but it seems that Numba does not allow null values. This is OK; I think I could simply use a shorter list in this scenario and keep track of an offset when making calculations, but it would be great to be able to track null values.

This was my last attempt, but it's not really useful as it doesn't compile.

@njit
def sma(src, p):
    res = src.copy()
    i = 0
    length = len(src)
    while i < length:
        win = src[max(0, i+1-p) : i+1]
        win_length = len(win)
        s = 0
        for n in win:
            s += n
        s /= win_length
        res[i] = s
    return res
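
One thing visible in the attempt above, independent of any compile error, is that i is never incremented, so the while loop could not terminate. The following is a minimal sketch of the same shrinking-window idea with that fixed. It assumes src is a one-dimensional NumPy float array rather than a Python list; the name sma_window is hypothetical, and the try/except only lets the sketch run as plain Python when Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: fall back to plain Python if Numba is not available.
    def njit(func):
        return func


@njit
def sma_window(src, p):
    """Moving average that uses a shorter window for the first p-1 values."""
    n = len(src)
    res = np.empty(n, dtype=np.float64)
    for i in range(n):
        start = max(0, i + 1 - p)      # first index of the current window
        s = 0.0                        # float accumulator
        for j in range(start, i + 1):
            s += src[j]
        res[i] = s / (i + 1 - start)   # divide by the actual window length
    return res


prices = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
result = sma_window(prices, 3)
print(result)  # [1.  1.5 2.  3.  4. ]
```

Indexing into src directly avoids creating a slice per element, and the for loop removes the manual counter that was missing its increment.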

Edit: I have a function that compiles and seems to be allowing for None values now. I don't know why it was giving me errors before. So now I have the following similar function:

@njit
def sma(src, p):
    slices = [src[i-p:i] for i in range(p,len(src)+1)]
    res = []
    for slc in slices:
        s = 0.0
        for i in range(p):
            s += slc[i]
        res.append(s/p)
    res = [None]*(p-1) + res
    return res

def sma_plain(src,p):
    win = []
    res = []
    for n in src:
        win.append(n)
        if len(win) > p:
            win = win[1:]
        if len(win) == p:
            res.append(sum(win)/len(win))
        else:
            res.append(None)
    return res

But for this, with 1000 iterations over some stock data, timeit reports 39 seconds for the Numba function and 7 seconds for the Python function. Now I'm wondering if Numba is even working any longer, as it isn't throwing errors for everything like it was. Are there caching issues with the compiled functions that cause it to use outdated versions or something?
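
Two general Numba behaviours may be relevant to timings like these, though whether they explain the numbers above is an assumption. First, a jitted function is compiled on its first call for each new combination of argument types, so that call includes compilation time. Second, a Python list argument has to be converted on every call, whereas a NumPy array is passed without copying. Compiled code is only cached on disk when cache=True is requested, so a redefined function is compiled afresh. The sketch below shows one way to time a jitted function with a warm-up call; sma_full_windows and the random data are hypothetical, and the try/except only lets it run without Numba.

```python
import timeit

import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: fall back to plain Python if Numba is not available.
    def njit(func):
        return func


@njit
def sma_full_windows(src, p):
    """Moving average over full windows only; earlier entries are NaN."""
    n = len(src)
    res = np.full(n, np.nan)
    for i in range(p - 1, n):
        s = 0.0
        for j in range(i - p + 1, i + 1):
            s += src[j]
        res[i] = s / p
    return res


# Convert the data to a NumPy array once, outside the timed code.
data = np.random.rand(2_000)

# Warm-up call: triggers compilation so it is not included in the timing.
warm = sma_full_windows(data, 20)

elapsed = timeit.timeit(lambda: sma_full_windows(data, 20), number=10)
print(f"{elapsed / 10:.6f} s per call")
```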

As long as you don't want to use Numpy, I won't recommend this one ;-):

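import numpy as np

def sma_numpy_acc(a, p):
    # Cumulative sum, already divided by the window length
    m = np.cumsum(a) / p
    # Window mean = scaled cumsum at i minus scaled cumsum at i-p
    m[p:] -= m[:-p]
    # The first p-1 positions have no full window
    m[:p-1] = np.nan
    return m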

Note I'm using NaN instead of None, so that the array can have a homogeneous type.
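
As a small worked check of the cumulative-sum idea, the sketch below repeats sma_numpy_acc and compares it with a brute-force mean over each full window. The five-element input and the brute-force loop are illustrative assumptions, not part of the original answer.

```python
import numpy as np


def sma_numpy_acc(a, p):
    m = np.cumsum(a) / p       # scaled cumulative sum
    m[p:] -= m[:-p]            # difference of cumulative sums p apart
    m[:p-1] = np.nan           # no full window yet
    return m


a = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
p = 3

fast = sma_numpy_acc(a, p)

# Brute-force reference: mean of each full window, NaN elsewhere.
slow = np.full(len(a), np.nan)
for i in range(p - 1, len(a)):
    slow[i] = a[i - p + 1:i + 1].mean()

# Both are approximately [nan, nan, 2, 3, 4], up to floating-point rounding.
print(fast)
print(slow)
```

For long series with large values, differencing cumulative sums can lose some floating-point precision compared with summing each window directly, which is why the comparison uses a tolerance rather than exact equality.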

Timing compared to the original functions:

%timeit sma(a, p)
88.6 ms ± 1.58 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit sma_plain(a, p)
2.18 s ± 65.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit sma_numpy_acc(a, p)
3.95 ms ± 56.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Speed can be slightly increased by jitting the function:

import numba as nb
import numpy as np

@nb.njit
def sma_numpy_acc_jit(a, p):
    m = np.cumsum(a) / p
    m[p:] = m[p:] - m[:-p]        # Odd behavior of -= in Numba
    m[:p - 1] = np.nan
    return m

%timeit sma_numpy_acc_jit(a, p)
3 ms ± 66.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

The same idea of using a cumulative sum, still using Numpy arrays, but not Numpy functions:

@nb.njit
def sma_jit_acc(a, p):
    acc = np.empty_like(a)
    acc[0] = a[0]
    n = len(a)
    for i in range(1, n):
        acc[i] = acc[i-1] + a[i]
    for i in range(n-1, p-1, -1):
        acc[i] = (acc[i] - acc[i-p]) / p
    acc[p-1] /= p
    for i in range(p-1):
        acc[i] = np.nan
    return acc

Timing is similar to the pure Numpy function.

%timeit sma_jit_acc(a, p)
3.69 ms ± 119 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

The same method using lists. No trace of Numpy:

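import math
import numba as nb

@nb.njit
def sma_jit_acc_lists(a, p):
    n = len(a)
    acc = [math.nan] * n
    acc[0] = a[0]
    for i in range(1, n):
        acc[i] = acc[i-1] + a[i]
    for i in range(n-1, p-1, -1):
        acc[i] = (acc[i] - acc[i-p]) / p
    acc[p-1] /= p
    # The first p-1 entries still hold partial sums; mark them as NaN
    for i in range(p-1):
        acc[i] = math.nan
    return acc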

Timing is degraded by the use of lists:

%timeit sma_jit_acc_lists(a, p)
24.2 ms ± 1.47 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
