繁体   English   中英

Numba 使 python 代码在一个简单的 for 循环中运行得更慢(不使用 numpy)

[英]Numba makes python code run slower in a simple for-loop (Not using numpy)

我正在尝试使用 Numba 并尝试运行此代码(映射一个大列表):

from numba import njit, jit
from datetime import datetime


big_list = [(i, i + 10000) for i in range(1, 100000000)]

#Just a number of arithmetic operations (Using Numba).
@njit(cache=True)
def just_calc_jit(row):
    exp_1 = row[1] / row[0]
    exp_2 = (row[0] + 10000) / row[1]
    exp_3 = (exp_2 - row[0]) / exp_1
    exp_3 *= exp_3
    return exp_3

# Same function without Numba.
def just_calc(row):
    exp_1 = row[1] / row[0]
    exp_2 = (row[0] + 10000) / row[1]
    exp_3 = (exp_2 - row[0]) / exp_1
    exp_3 *= exp_3
    return exp_3

# Prints execution times (with and without Numba) 5 times for every function.
for i in range(5):
    start = datetime.now()
    result = list(map(just_calc, big_list))
    execution_time = datetime.now() - start
    print("execution time:", execution_time)

    start = datetime.now()
    result = list(map(just_calc, big_list))
    execution_time = datetime.now() - start
    print("execution time jit:", execution_time)

这是脚本的输出(您可以看到每次使用和不使用 Numba 5 次的执行时间):

execution time: 0:00:17.643550
execution time jit: 0:00:19.780514
execution time: 0:00:19.072673
execution time jit: 0:00:18.961395
execution time: 0:00:20.567786
execution time jit: 0:00:20.119370
execution time: 0:00:21.254276
execution time jit: 0:00:20.034304
execution time: 0:00:20.219750
execution time jit: 0:00:19.237941

我错过了什么/做错了什么?

我在您的代码中更改了两件事,并在运行时使用numba获得了更好的结果:(numba函数的第一次运行中,我们得到了一个不好的结果。)

  1. 循环遍历函数中的big_list行。
  2. big_list作为numpy.asarray输入到函数中。
from datetime import datetime
import numba as nb
import numpy as np


big_list = [(i, i + 10_000) for i in range(1, 100_000)]

#Just a number of arithmetic operations (Using Numba).
@nb.njit(parallel = True)
def just_calc_jit(arr):
    num_row = len(arr)
    res = np.empty((num_row))
    for i in nb.prange(num_row):
        exp_1 = arr[i][1] / arr[i][0]
        exp_2 = (arr[i][0] + 10000) / arr[i][1]
        exp_3 = (exp_2 - arr[i][0]) / exp_1
        exp_3 *= exp_3
        res[i] = exp_3
    return res

# Same function without Numba.
def just_calc(arr):
    num_row = len(arr)
    res = np.empty((num_row))
    for i in range(num_row):
        exp_1 = arr[i][1] / arr[i][0]
        exp_2 = (arr[i][0] + 10000) / arr[i][1]
        exp_3 = (exp_2 - arr[i][0]) / exp_1
        exp_3 *= exp_3
        res[i] = exp_3
    return res

# Prints execution times (with and without Numba) 5 times for every function.
for i in range(5):
    start = datetime.now()
    result = just_calc(np.asarray(big_list))
    execution_time = datetime.now() - start
    print("execution time:", execution_time)

    start = datetime.now()
    result = just_calc_jit(np.asarray(big_list))
    execution_time = datetime.now() - start
    print("execution time jit:", execution_time)

输出: colab上的基准测试)

execution time: 0:00:00.323675
execution time jit: 0:00:00.639346
execution time: 0:00:00.237574
execution time jit: 0:00:00.046685
execution time: 0:00:00.222264
execution time jit: 0:00:00.048550
execution time: 0:00:00.223323
execution time jit: 0:00:00.049903
execution time: 0:00:00.222570
execution time jit: 0:00:00.049623

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM