
Why is numpy.where much faster than the alternatives?

I'm trying to speed up the following code:

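import time
import numpy as np

np.random.seed(10)
b = np.random.rand(10000, 1000)

def f(a=1):
    tott = 0
    for _ in range(a):
        q = np.array(b)  # fresh copy of b for every run
        t1 = time.time()
        # Element-by-element double loop in pure Python (the slow part)
        for i in range(len(q)):
            for j in range(len(q[0])):
                if q[i][j] > 0.5:
                    q[i][j] = 1
                else:
                    q[i][j] = -1
        t2 = time.time()
        tott += t2 - t1
    print(tott / a)  # average seconds per run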

As you can see, the function mainly iterates in a double loop. I tried replacing the loop with np.nditer, np.vectorize and map. That gives some speedup (about 4-5x, except for np.nditer), but with np.where(q>0.5,1,-1) the speedup is almost 100x. How can I iterate over NumPy arrays as fast as np.where does? And why is it so much faster?

It's because the core of NumPy is implemented in C. You're basically comparing the speed of C with the speed of Python.

If you want NumPy's speed advantage, make as few calls as possible from your Python code. Once you use a Python loop you have already lost, even if the loop body only calls NumPy functions. Use the higher-level functions NumPy provides (that's why it ships so many special functions). Internally they run a much more efficient C loop.
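As a minimal sketch of that advice, the double loop from the question collapses into a single np.where call. The small array shape and the names q_where and q_loop are only illustrative assumptions; the explicit loop is kept solely to check that both approaches produce the same result.

```python
import numpy as np

rng = np.random.default_rng(10)
b = rng.random((50, 40))  # small illustrative array; the question uses (10000, 1000)

# One vectorized call: the comparison and the selection both run in compiled loops.
q_where = np.where(b > 0.5, 1.0, -1.0)

# Reference result from the explicit Python double loop, for comparison only.
q_loop = np.empty_like(b)
for i in range(b.shape[0]):
    for j in range(b.shape[1]):
        q_loop[i, j] = 1.0 if b[i, j] > 0.5 else -1.0

same = np.array_equal(q_where, q_loop)
print(same)
```

The Python-level work here is a constant handful of calls regardless of array size, whereas the loop executes the interpreter once per element.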

You can also implement the function in C yourself (with loops) and call it from Python. That should give comparable speed.

To answer the question: you can get the same speed (about a 100x speedup) by using the numba library:

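import numpy as np
from numba import njit

def f(b):
    # Output array with the same shape and dtype as b
    q = np.zeros_like(b)

    for i in range(b.shape[0]):
        for j in range(b.shape[1]):
            # Test the input b, not the freshly zeroed output q
            if b[i, j] > 0.5:
                q[i, j] = 1
            else:
                q[i, j] = -1

    return q

@njit
def f_jit(b):
    # Same loop as f, compiled to machine code on the first call
    q = np.zeros_like(b)

    for i in range(b.shape[0]):
        for j in range(b.shape[1]):
            if b[i, j] > 0.5:
                q[i, j] = 1
            else:
                q[i, j] = -1

    return q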

Compare the speed:

Plain Python

%timeit f(b)
592 ms ± 5.72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Numba (just-in-time compiled using LLVM, roughly C speed)

%timeit f_jit(b)
5.97 ms ± 105 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
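%timeit is an IPython magic, so the measurements above only work in IPython or Jupyter. In a plain script, a rough equivalent can be sketched with the standard timeit module. The function names loop_version and where_version and the small array are hypothetical, chosen so the loop finishes quickly; absolute timings will vary by machine and are not shown here.

```python
import timeit
import numpy as np

rng = np.random.default_rng(10)
small = rng.random((200, 100))  # hypothetical, much smaller than the question's array

def loop_version(a):
    # Pure-Python double loop, one interpreter step per element
    q = np.empty_like(a)
    for i in range(a.shape[0]):
        for j in range(a.shape[1]):
            q[i, j] = 1.0 if a[i, j] > 0.5 else -1.0
    return q

def where_version(a):
    # Single vectorized call
    return np.where(a > 0.5, 1.0, -1.0)

# Taking the best of several repeats reduces noise from other processes.
t_loop = min(timeit.repeat(lambda: loop_version(small), number=1, repeat=3))
t_where = min(timeit.repeat(lambda: where_version(small), number=10, repeat=3)) / 10

print(f"loop:     {t_loop * 1e3:.3f} ms per call")
print(f"np.where: {t_where * 1e3:.3f} ms per call")
```

When timing a Numba function the same way, call it once before measuring so that the one-time compilation cost is not counted.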
