简体   繁体   English

numba中的性能嵌套循环

[英]Performance nested loop in numba

For performance reasons, I have started to use Numba besides NumPy. 出于性能原因,除了NumPy之外,我还开始使用Numba。 My Numba algorithm is working, but I have the feeling that it should be faster. 我的Numba算法正在运行,但是我觉得它应该更快。 There is one point which is slowing it down. 有一点使它放慢了速度。 Here is the code snippet: 这是代码片段:

@nb.njit
def rfunc1(ws, a, l):
    gn = a**l
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
                for i in range(1, l):
                    if numpy.all(ws[x1][0:i] == ws[x2][0:i]) and
                    numpy.all(ws[x1][i:l] == ws[x3][i:l]):
                        y += 1
                    if numpy.all(ws[x1][0:i] == ws[x2][0:i]) and 
                    numpy.all(ws[x1][i:l] == ws[x3][i:l]):
                        y += 1

In my opinion the if command is slowing it down. 我认为if命令会降低它的速度。 Is there a better way? 有没有更好的办法? (What I try to achieve here is related to a previous posted problem: Count possibilites for single crossovers ) ws is a NumPy array of size (gn, l) containing 0 's and 1 's (我在此处尝试实现的功能与先前发布的问题有关: 单交叉的计数可能性ws是大小为(gn, l)的NumPy数组,包含01

Given the logic of wanting to ensure all items are equal, you can take advantage of the fact that if any are not equal, you can short-circuit (ie stop comparing) the calculation. 鉴于希望确保所有项目相等的逻辑,您可以利用以下事实:如果有任何一项不相等,则可以使计算短路(即停止比较)。 I modified your original function slightly so that (1) you don't repeat the same comparison twice, and (2) sum y over the all nested loops so there was a return that could be compared: 我稍微修改了原始函数,以使(1)您不会重复相同的比较两次,并且(2)在所有嵌套循环中求和,因此可以比较返回值:

@nb.njit
def rfunc1(ws, a, l):
    gn = a**l
    ysum = 0
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
                for i in range(1, l):
                    if np.all(ws[x1][0:i] == ws[x2][0:i]) and np.all(ws[x1][i:l] == ws[x3][i:l]):
                        y += 1
                        ysum += 1

    return ysum


@nb.njit
def rfunc2(ws, a, l):
    gn = a**l
    ysum = 0
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
                for i in range(1, l):

                    incr_y = True
                    for j in range(i):
                        if ws[x1,j] != ws[x2,j]:
                            incr_y = False
                            break

                    if incr_y is True:
                        for j in range(i,l):
                            if ws[x1,j] != ws[x3,j]:
                                incr_y = False
                                break
                    if incr_y is True:
                        y += 1
                        ysum += 1
    return ysum

I don't know what the complete function looks like, but hopefully this helps you get started on the right path. 我不知道完整的功能是什么样子,但是希望这可以帮助您开始正确的道路。

Now for some timings: 现在是一些时间:

l = 7
a = 2
gn = a**l
ws = np.random.randint(0,2,size=(gn,l))
In [23]:

%timeit rfunc1(ws, a , l)
1 loop, best of 3: 2.11 s per loop


%timeit rfunc2(ws, a , l)
1 loop, best of 3: 39.9 ms per loop

In [27]: rfunc1(ws, a , l)
Out[27]: 131919

In [30]: rfunc2(ws, a , l)
Out[30]: 131919

That gives you a 50x speed-up. 这使您的速度提高了50倍。

Instead of just "having a feeling" where your bottleneck is, why not profile your code and find exactly where? 为何不仅仅“感觉”瓶颈在哪里,为什么不分析您的代码并确切地找到位置?

The first aim of profiling is to test a representative system to identify what's slow (or using too much RAM, or causing too much disk I/O or network I/O). 进行概要分析的首要目的是测试一个代表性的系统,以找出运行缓慢的系统(或使用过多的RAM,或引起过多的磁盘I / O或网络I / O)。

Profiling typically adds an overhead (10x to 100x slowdowns can be typical), and you still want your code to be used as similarly to in a real-world situation as possible. 分析通常会增加开销(通常会降低10到100倍的速度),并且您仍然希望代码的使用尽可能类似于实际情况。 Extract a test case and isolate the piece of the system that you need to test. 提取一个测试用例,并隔离您需要测试的系统部分。 Preferably, it'll have been written to be in its own set of modules already. 最好将其编写为已经在其自己的模块集中。

Basic techniques include the %timeit magic in IPython, time.time(), and a timing decorator (see example below). 基本技术包括IPython中的%timeit魔术, time.time(),timing decorator (请参见下面的示例)。 You can use these techniques to understand the behavior of statements and functions. 您可以使用这些技术来了解语句和函数的行为。

Then you have cProfile which will give you a high-level view of the problem so you can direct your attention to the critical functions. 然后,您有了cProfile ,可以从更高层次查看问题,因此您可以将注意力转移到关键功能上。

Next, look at line_profiler, which will profile your chosen functions on a line-by-line basis. 接下来,查看line_profiler,它将line_profiler,分析所选功能。 The result will include a count of the number of times each line is called and the percentage of time spent on each line. 结果将包括每条线被调用的次数以及每条线所花费的时间百分比的计数。 This is exactly the information you need to understand what's running slowly and why. 这正是您了解运行缓慢以及原因的信息。

perf stat helps you understand the number of instructions that are ultimately executed on a CPU and how efficiently the CPU's caches are utilized. perf stat可帮助您了解最终在CPU上执行的指令数量以及CPU缓存的利用效率。 This allows for advanced-level tuning of matrix operations. 这允许对矩阵运算进行高级调整。

heapy can track all of the objects inside Python's memory. heapy可以跟踪Python内存中的所有对象。 This is great for hunting down strange memory leaks. 这对于解决奇怪的内存泄漏非常有用。 If you're working with long-running systems, then dowser will interest you: it allows you to introspect live objects in a long-running process via a web browser interface. 如果您正在使用长时间运行的系统,则dowser会引起您的兴趣:它使您可以通过Web浏览器界面在长时间运行的过程中对活动对象进行自省。

To help you understand why your RAM usage is high, check out memory_profiler. 为了帮助您了解为什么RAM使用率很高,请查看memory_profiler. It is particularly useful for tracking RAM usage over time on a labeled chart, so you can explain to colleagues (or yourself) why certain functions use more RAM than expected. 这对于在带标签的图表上随时间推移跟踪RAM使用情况特别有用,因此您可以向同事(或您自己)解释为什么某些功能使用的RAM比预期更多。

Example: Defining a decorator to automate timing measurements 示例:定义装饰器以自动进行时序测量

from functools import wraps

def timefn(fn):
    @wraps(fn)
    def measure_time(*args, **kwargs):
        t1 = time.time()
        result = fn(*args, **kwargs)
        t2 = time.time()
        print ("@timefn:" + fn.func_name + " took " + str(t2 - t1) + " seconds")
        return result
    return measure_time

@timefn
def your_func(var1, var2):
    ...

For more information, I suggest reading High performance Python (Micha Gorelick; Ian Ozsvald) from which the above was sourced. 有关更多信息,建议阅读以上内容的高性能Python (Micha Gorelick; Ian Ozsvald)。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM