[英]Optimize the python function with numpy without using the for loop
[英]Optimize code for step function using only NumPy
我正在尝试仅使用 NumPy 函数(或者可能是列表推导式)优化以下代码中的函数“pw”。
from time import time
import numpy as np
def pw(x, udata):
"""
Creates the step function
| 1, if d0 <= x < d1
| 2, if d1 <= x < d2
pw(x,data) = ...
| N, if d(N-1) <= x < dN
| 0, otherwise
where di is the ith element in data.
INPUT: x -- interval which the step function is defined over
data -- an ordered set of data (without repetitions)
OUTPUT: pw_func -- an array of size x.shape[0]
"""
vals = np.arange(1,udata.shape[0]+1).reshape(udata.shape[0],1)
pw_func = np.sum(np.where(np.greater_equal(x,udata)*np.less(x,np.roll(udata,-1)),vals,0),axis=0)
return pw_func
N = 50000
x = np.linspace(0,10,N)
data = [1,3,4,5,5,7]
udata = np.unique(data)
ti = time()
pw(x,udata)
tf = time()
print(tf - ti)
import cProfile
cProfile.run('pw(x,udata)')
cProfile.run 告诉我大部分开销来自 np.where(大约 1 毫秒),但如果可能的话,我想创建更快的代码。 似乎按行和按列执行操作会有所不同,除非我弄错了,但我想我已经考虑到了。 我知道有时列表理解会更快,但我想不出比我正在做的更快的方法。
Searchsorted 似乎产生更好的性能,但 1 ms 仍然保留在我的计算机上:
(modified)
def pw(xx, uu):
"""
Creates the step function
| 1, if d0 <= x < d1
| 2, if d1 <= x < d2
pw(x,data) = ...
| N, if d(N-1) <= x < dN
| 0, otherwise
where di is the ith element in data.
INPUT: x -- interval which the step function is defined over
data -- an ordered set of data (without repetitions)
OUTPUT: pw_func -- an array of size x.shape[0]
"""
inds = np.searchsorted(uu, xx, side='right')
vals = np.arange(1,uu.shape[0]+1)
pw_func = vals[inds[inds != uu.shape[0]]]
num_mins = np.sum(xx < np.min(uu))
num_maxs = np.sum(xx > np.max(uu))
pw_func = np.concatenate((np.zeros(num_mins), pw_func, np.zeros(xx.shape[0]-pw_func.shape[0]-num_mins)))
return pw_func
这个使用分段的答案似乎非常接近,但这是在标量 x0 和 x1 上。 我将如何在阵列上做到这一点? 它会更有效率吗?
可以理解,x 可能相当大,但我正在尝试对其进行压力测试。
我仍在学习,所以一些可以帮助我的提示或技巧会很棒。
编辑
第二个函数似乎有错误,因为第二个函数的结果数组与第一个不匹配(我相信它可以工作):
N1 = pw1(x,udata.reshape(udata.shape[0],1)).shape[0]
N2 = np.sum(pw1(x,udata.reshape(udata.shape[0],1)) == pw2(x,udata))
print(N1 - N2)
产量
15000
不相同的数据点。 所以似乎我不知道如何使用'searchsorted'。
编辑 2
实际上我修复了它:
pw_func = vals[inds[inds != uu.shape[0]]]
改为
pw_func = vals[inds[inds[(inds != uu.shape[0])*(inds != 0)]-1]]
所以至少结果数组匹配。 但问题仍然是是否有更有效的方法来做到这一点。
编辑 3
感谢天丽指出错误。 这个应该工作
pw_func = vals[inds[(inds != uu.shape[0])*(inds != 0)]-1]
也许一种更易读的呈现方式是
non_endpts = (inds != uu.shape[0])*(inds != 0) # only consider the points in between the min/max data values
shift_inds = inds[non_endpts]-1 # searchsorted side='right' includes the left end point and not right end point so a shift is needed
pw_func = vals[shift_inds]
我想我迷失在所有这些括号中! 我想这就是可读性的重要性。
一个非常抽象但有趣的问题! 谢谢你招待我,我玩得很开心:)
ps 我不确定你的pw2
我无法得到与pw1
相同的输出。
供参考原始pw
:
def pw1(x, udata):
vals = np.arange(1,udata.shape[0]+1).reshape(udata.shape[0],1)
pw_func = np.sum(np.where(np.greater_equal(x,udata)*np.less(x,np.roll(udata,-1)),vals,0),axis=0)
return pw_func
def pw2(xx, uu):
inds = np.searchsorted(uu, xx, side='right')
vals = np.arange(1,uu.shape[0]+1)
pw_func = vals[inds[inds[(inds != uu.shape[0])*(inds != 0)]-1]]
num_mins = np.sum(xx < np.min(uu))
num_maxs = np.sum(xx > np.max(uu))
pw_func = np.concatenate((np.zeros(num_mins), pw_func, np.zeros(xx.shape[0]-pw_func.shape[0]-num_mins)))
return pw_func
我的第一次尝试是利用numpy
的大量广播操作:
def pw3(x, udata):
# the None slice is to create new axis
step_bool = x >= udata[None,:].T
# we exploit the fact that bools are integer value of 1s
# skipping the last value in "data"
step_vals = np.sum(step_bool[:-1], axis=0)
# for the step_bool that we skipped from previous step (last index)
# we set it to zerp so that we can negate the step_vals once we reached
# the last value in "data"
step_vals[step_bool[-1]] = 0
return step_vals
从您的searchsorted
中查看搜索pw2
我有了一种新方法,可以以更高的性能利用它:
def pw4(x, udata):
inds = np.searchsorted(udata, x, side='right')
# fix-ups the last data if x is already out of range of data[-1]
if x[-1] > udata[-1]:
inds[inds == inds[-1]] = 0
return inds
情节与:
plt.plot(pw1(x,udata.reshape(udata.shape[0],1)), label='pw1')
plt.plot(pw2(x,udata), label='pw2')
plt.plot(pw3(x,udata), label='pw3')
plt.plot(pw4(x,udata), label='pw4')
data = [1,3,4,5,5,7]
:
data = [1,3,4,5,5,7,11]
pw1
, pw3
, pw4
都是一样的
print(np.all(pw1(x,udata.reshape(udata.shape[0],1)) == pw3(x,udata)))
>>> True
print(np.all(pw1(x,udata.reshape(udata.shape[0],1)) == pw4(x,udata)))
>>> True
性能: ( timeit
默认运行3次,平均number=N
次)
print(timeit.Timer('pw1(x,udata.reshape(udata.shape[0],1))', "from __main__ import pw1, x, udata").repeat(number=1000))
>>> [3.1938983199979702, 1.6096494779994828, 1.962694135003403]
print(timeit.Timer('pw2(x,udata)', "from __main__ import pw2, x, udata").repeat(number=1000))
>>> [0.6884554479984217, 0.6075002400029916, 0.7799002879983163]
print(timeit.Timer('pw3(x,udata)', "from __main__ import pw3, x, udata").repeat(number=1000))
>>> [0.7369808239964186, 0.7557657590004965, 0.8088172269999632]
print(timeit.Timer('pw4(x,udata)', "from __main__ import pw4, x, udata").repeat(number=1000))
>>> [0.20514375300263055, 0.20203858999957447, 0.19906871100101853]
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.