Improve performance of Python script using Numba jit
I am running a sample Python simulation of a weighted die and a regular die. I would like to use Numba to help speed up my script, but I get the following warning:
```
<timed exec>:6: NumbaWarning:
Compilation is falling back to object mode WITH looplifting enabled because Function "roll" failed type inference due to: Untyped global name 'sum': cannot determine Numba type of <class 'builtin_function_or_method'>
File "<timed exec>", line 9:
<source missing, REPL/exec in use?>
```
Is there another type of Numba expression I can use instead? Right now I'm testing with an input of 2,500 rolls, and I want to get the runtime down to 4 seconds (it currently takes 8.5 seconds). Here is my original code:
```python
%%time
from numba import jit
import random
import matplotlib.pyplot as plt
import numpy

@jit
def roll(sides, bias_list):
    assert len(bias_list) == sides, "Enter correct number of dice sides"
    # Numba's warning points at this call to the builtin sum()
    number = random.uniform(0, sum(bias_list))
    current = 0
    for i, bias in enumerate(bias_list):
        current += bias
        if number <= current:
            return i + 1

no_of_rolls = 2500
weighted_die = {}  # roll number -> face rolled
normal_die = {}    # roll number -> face rolled

# weighted die
for i in range(no_of_rolls):
    weighted_die[i+1] = roll(6, (0.15, 0.15, 0.15, 0.15, 0.15, 0.25))

# regular die
for i in range(no_of_rolls):
    normal_die[i+1] = roll(6, (0.167, 0.167, 0.167, 0.167, 0.167, 0.165))

plt.bar(*zip(*weighted_die.items()))
plt.show()
plt.bar(*zip(*normal_die.items()))
plt.show()
```
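A minimal sketch of one way to get `roll` to compile in nopython mode, assuming the only blocker is the untyped builtin `sum()` named in the warning: sum the weights with an explicit loop. It uses `numba.njit`, which raises an error instead of silently falling back to object mode. The no-op `njit` fallback for machines without Numba is an assumption added so the sketch runs anywhere; it is not part of the question.

```python
import random

try:
    from numba import njit
except ImportError:
    # Assumption: fall back to plain Python so the sketch runs without Numba
    def njit(func):
        return func


@njit
def roll(sides, bias_list):
    """Return a face in 1..sides, chosen with probability proportional to its bias."""
    assert len(bias_list) == sides, "Enter correct number of dice sides"
    # Explicit loop instead of the builtin sum() that Numba could not type
    total = 0.0
    for bias in bias_list:
        total += bias
    number = random.uniform(0, total)
    current = 0.0
    for i in range(len(bias_list)):
        current += bias_list[i]
        if number <= current:
            return i + 1
    # Defensive fallback: every path returns an int rather than None
    return sides


no_of_rolls = 2500
weighted_die = {}
normal_die = {}
for i in range(no_of_rolls):
    weighted_die[i + 1] = roll(6, (0.15, 0.15, 0.15, 0.15, 0.15, 0.25))
for i in range(no_of_rolls):
    normal_die[i + 1] = roll(6, (0.167, 0.167, 0.167, 0.167, 0.167, 0.165))
```

The first call to `roll` triggers JIT compilation, so a `%%time` around the whole cell also includes that one-off cost.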
Using Random Choices
Refactored Code
```python
import random
import matplotlib.pyplot as plt

no_of_rolls = 2500

# weights
normal_weights = (0.167, 0.167, 0.167, 0.167, 0.167, 0.165)
bias_weights = (0.15, 0.15, 0.15, 0.15, 0.15, 0.25)

# Replaced roll function with random.choices
# Reference: https://www.w3schools.com/python/ref_random_choices.asp
bias_rolls = random.choices(range(1, 7), weights=bias_weights, k=no_of_rolls)
normal_rolls = random.choices(range(1, 7), weights=normal_weights, k=no_of_rolls)

# Create dictionaries with same structure as posted code (roll number, from 1 -> face)
weighted_die = dict(zip(range(1, no_of_rolls + 1), bias_rolls))
normal_die = dict(zip(range(1, no_of_rolls + 1), normal_rolls))

# Use posted plotting calls
plt.bar(*zip(*weighted_die.items()))
plt.show()
plt.bar(*zip(*normal_die.items()))
plt.show()
```
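The plotting calls above draw one bar per roll, keyed by roll number. If the goal is instead to compare how often each face comes up (an assumption about intent that the question does not state), a hedged sketch is to tally faces with `collections.Counter`, which leaves only six bars per die to draw:

```python
import random
from collections import Counter

no_of_rolls = 2500
normal_weights = (0.167, 0.167, 0.167, 0.167, 0.167, 0.165)
bias_weights = (0.15, 0.15, 0.15, 0.15, 0.15, 0.25)

bias_rolls = random.choices(range(1, 7), weights=bias_weights, k=no_of_rolls)
normal_rolls = random.choices(range(1, 7), weights=normal_weights, k=no_of_rolls)

# Count how many times each face 1..6 was rolled
bias_counts = Counter(bias_rolls)
normal_counts = Counter(normal_rolls)

# Faces in order with their counts (Counter returns 0 for a face never rolled)
faces = list(range(1, 7))
bias_freq = [bias_counts[f] for f in faces]
normal_freq = [normal_counts[f] for f in faces]

# With matplotlib available, each die then needs only six bars, e.g.:
# plt.bar(faces, bias_freq); plt.show()
print(bias_freq)
print(normal_freq)
```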
Performance
*Not including plotting.*
Original code: ~6 ms
Revised code: ~2 ms
(A 3x improvement, though I'm not sure why the question reports 8 seconds of runtime.)
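The answer does not say how these timings were taken. One way to reproduce a comparison that excludes plotting is the standard-library `timeit` module; this sketch times a pure-Python copy of the question's `roll` (without `@jit`) against `random.choices`. Timings vary by machine, so no expected numbers are given.

```python
import random
import timeit

no_of_rolls = 2500
bias_weights = (0.15, 0.15, 0.15, 0.15, 0.15, 0.25)


def roll(sides, bias_list):
    # Pure-Python copy of the question's roll(), without the @jit decorator
    assert len(bias_list) == sides, "Enter correct number of dice sides"
    number = random.uniform(0, sum(bias_list))
    current = 0
    for i, bias in enumerate(bias_list):
        current += bias
        if number <= current:
            return i + 1


def loop_version():
    return {i + 1: roll(6, bias_weights) for i in range(no_of_rolls)}


def choices_version():
    rolls = random.choices(range(1, 7), weights=bias_weights, k=no_of_rolls)
    return dict(zip(range(1, no_of_rolls + 1), rolls))


# Best of 5 repeats of 10 runs each, converted to milliseconds per run
loop_ms = min(timeit.repeat(loop_version, number=10, repeat=5)) / 10 * 1000
choices_ms = min(timeit.repeat(choices_version, number=10, repeat=5)) / 10 * 1000
print(f"roll() loop:    {loop_ms:.2f} ms")
print(f"random.choices: {choices_ms:.2f} ms")
```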
You can accelerate it using `guvectorize`:
```python
%%time
from numba import guvectorize
import matplotlib.pyplot as plt
import numpy as np
import random

sides = 6
bias_list = (0.15, 0.15, 0.15, 0.15, 0.15, 0.25)

# Layout "(n, k) -> (n)": an (n, k) array of biases in, one face per row out
@guvectorize(["f8[:,:], uint8[:]"], "(n, k) -> (n)", nopython=True)
def roll(biases, side):
    for i in range(biases.shape[0]):
        # np.sum replaces the builtin sum() that failed type inference in the question
        number = random.uniform(0, np.sum(biases[i, :]))
        current = 0.0
        for j, bias in enumerate(biases[i, :]):
            current += bias
            if number <= current:
                side[i] = j + 1
                break

no_of_rolls = 2500
# Repeat the same row of biases for every roll
biases = np.zeros((no_of_rolls, len(bias_list)))
biases[:, :] = np.array(bias_list)
normal_die = roll(biases)
print(normal_die)
```
This took ~200 ms on my PC, while your code took about 6 s.
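The `roll` logic in both versions is inverse-CDF sampling: walk the running sum of the biases until it reaches a uniform draw. As a swapped-in alternative that does not use Numba at all, NumPy can apply the same test to every roll at once with `np.cumsum` and `np.searchsorted`. This is a sketch under that assumption, not something either answer benchmarked; the function name `roll_many` is hypothetical.

```python
import numpy as np


def roll_many(bias_list, no_of_rolls, rng=None):
    """Hypothetical vectorized roll(): returns an array of faces in 1..len(bias_list)."""
    rng = np.random.default_rng() if rng is None else rng
    # Running sum of the biases: the same values `current` takes in the loop
    cdf = np.cumsum(bias_list)
    # One uniform draw per roll, like random.uniform(0, sum(bias_list))
    numbers = rng.uniform(0, cdf[-1], size=no_of_rolls)
    # side="left" gives the first index where number <= cdf[index],
    # matching the `if number <= current` test; +1 turns indices into faces
    return np.searchsorted(cdf, numbers, side="left") + 1


bias_list = (0.15, 0.15, 0.15, 0.15, 0.15, 0.25)
weighted_rolls = roll_many(bias_list, 2500)
# How often each face 1..6 came up
face_counts = np.bincount(weighted_rolls, minlength=7)[1:]
print(face_counts)
```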