Numba makes python code run slower in a simple for-loop (Not using numpy)
I am experimenting with Numba and tried running this code (mapping a function over a large list):
from numba import njit, jit
from datetime import datetime
big_list = [(i, i + 10000) for i in range(1, 100000000)]
#Just a number of arithmetic operations (Using Numba).
@njit(cache=True)
def just_calc_jit(row):
    exp_1 = row[1] / row[0]
    exp_2 = (row[0] + 10000) / row[1]
    exp_3 = (exp_2 - row[0]) / exp_1
    exp_3 *= exp_3
    return exp_3
# Same function without Numba.
def just_calc(row):
    exp_1 = row[1] / row[0]
    exp_2 = (row[0] + 10000) / row[1]
    exp_3 = (exp_2 - row[0]) / exp_1
    exp_3 *= exp_3
    return exp_3
# Prints execution times (with and without Numba) 5 times for every function.
for i in range(5):
    start = datetime.now()
    result = list(map(just_calc, big_list))
    execution_time = datetime.now() - start
    print("execution time:", execution_time)
    start = datetime.now()
    # Call the Numba-compiled version here so the "jit" label matches what is timed.
    result = list(map(just_calc_jit, big_list))
    execution_time = datetime.now() - start
    print("execution time jit:", execution_time)
Here is the script's output (execution times with and without Numba, five runs each):
execution time: 0:00:17.643550
execution time jit: 0:00:19.780514
execution time: 0:00:19.072673
execution time jit: 0:00:18.961395
execution time: 0:00:20.567786
execution time jit: 0:00:20.119370
execution time: 0:00:21.254276
execution time jit: 0:00:20.034304
execution time: 0:00:20.219750
execution time jit: 0:00:19.237941
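What am I missing or doing wrong?
I changed two things in your code and got better run times with Numba. (The first run of the Numba function is still slow, because that call includes the JIT compilation time.)
1. Reduced the number of rows in big_list.
2. Passed big_list to the function as a NumPy array, via np.asarray.
from datetime import datetime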
import numba as nb
import numpy as np
big_list = [(i, i + 10_000) for i in range(1, 100_000)]
#Just a number of arithmetic operations (Using Numba).
@nb.njit(parallel = True)
def just_calc_jit(arr):
    num_row = len(arr)
    res = np.empty((num_row))
    for i in nb.prange(num_row):
        exp_1 = arr[i][1] / arr[i][0]
        exp_2 = (arr[i][0] + 10000) / arr[i][1]
        exp_3 = (exp_2 - arr[i][0]) / exp_1
        exp_3 *= exp_3
        res[i] = exp_3
    return res
# Same function without Numba.
def just_calc(arr):
    num_row = len(arr)
    res = np.empty((num_row))
    for i in range(num_row):
        exp_1 = arr[i][1] / arr[i][0]
        exp_2 = (arr[i][0] + 10000) / arr[i][1]
        exp_3 = (exp_2 - arr[i][0]) / exp_1
        exp_3 *= exp_3
        res[i] = exp_3
    return res
# Prints execution times (with and without Numba) 5 times for every function.
for i in range(5):
    start = datetime.now()
    result = just_calc(np.asarray(big_list))
    execution_time = datetime.now() - start
    print("execution time:", execution_time)
    start = datetime.now()
    result = just_calc_jit(np.asarray(big_list))
    execution_time = datetime.now() - start
    print("execution time jit:", execution_time)
Output (benchmarked on Colab):
execution time: 0:00:00.323675
execution time jit: 0:00:00.639346
execution time: 0:00:00.237574
execution time jit: 0:00:00.046685
execution time: 0:00:00.222264
execution time jit: 0:00:00.048550
execution time: 0:00:00.223323
execution time jit: 0:00:00.049903
execution time: 0:00:00.222570
execution time jit: 0:00:00.049623