
Optimize code for step function using only NumPy

I'm trying to optimize the function 'pw' in the following code using only NumPy functions (or perhaps list comprehensions).

from time import time
import numpy as np

def pw(x, udata):
    """
    Creates the step function
                 | 1,  if d0 <= x < d1
                 | 2,  if d1 <= x < d2
    pw(x,data) = ...
                 | N, if d(N-1) <= x < dN
                 | 0, otherwise
    where di is the ith element in data.
    INPUT:      x    --  points at which the step function is evaluated
               udata --  a sorted array of data points (without repetitions)
    OUTPUT: pw_func  --  an array of size x.shape[0]
    """
    # make udata a (k, 1) column so the comparisons with the 1-D x broadcast
    # to a (k, N) matrix; a 1-D udata of length k != N raises a ValueError
    udata = np.asarray(udata).reshape(-1, 1)
    vals = np.arange(1,udata.shape[0]+1).reshape(udata.shape[0],1)
    pw_func = np.sum(np.where(np.greater_equal(x,udata)*np.less(x,np.roll(udata,-1)),vals,0),axis=0)
    return pw_func


N = 50000
x = np.linspace(0,10,N)
data = [1,3,4,5,5,7]
udata = np.unique(data)

ti = time()
pw(x,udata)
tf = time()
print(tf - ti)

import cProfile
cProfile.run('pw(x,udata)')
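The comparisons in pw depend on NumPy broadcasting, which aligns trailing axes: a 1-D x of length N and a 1-D udata of length k only broadcast when the lengths match or one of them is 1. As a (k, 1) column, udata broadcasts against x into a (k, N) matrix with one row per data point, which is why the answer below calls pw1(x, udata.reshape(udata.shape[0],1)). A small illustration with made-up sizes:

```python
import numpy as np

x = np.linspace(0, 10, 8)           # N = 8 evaluation points (made up)
udata = np.array([1, 3, 4, 5, 7])   # k = 5 sorted, unique data points

# 1-D (8,) against 1-D (5,): trailing axes 8 and 5 differ, so this raises
try:
    np.greater_equal(x, udata)
    broadcast_ok = True
except ValueError:
    broadcast_ok = False

# (5, 1) column against (8,): broadcasts to a (5, 8) matrix;
# row i answers "is x >= udata[i]?" for every x
col = udata.reshape(-1, 1)
mask = np.greater_equal(x, col)
print(broadcast_ok, mask.shape)     # False (5, 8)
```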

The cProfile.run output tells me that most of the overhead comes from np.where (about 1 ms), but I'd like faster code if possible. Performing the operations row-wise versus column-wise seems to make some difference, unless I'm mistaken, but I think I've accounted for it. I know list comprehensions can sometimes be faster, but I couldn't find a faster way to use them here.
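Since the profile points at np.where, one possible direction is to avoid it altogether: slice udata into lower and upper bounds (instead of np.roll, which wraps the last bound around to the first) and let a single matrix product turn the boolean interval mask into step values. This is a sketch, not benchmarked here; pw_matmul is a hypothetical name, and it assumes udata is sorted and unique like the np.unique output above.

```python
import numpy as np

def pw_matmul(x, udata):
    """Hypothetical variant of pw without np.where (a sketch, not benchmarked).

    Uses explicit lower/upper bound slices instead of np.roll, then turns the
    (k-1, N) boolean interval mask into step values with one matrix product.
    """
    u = np.asarray(udata)
    lo = u[:-1, None]                 # d0 .. d(k-2) as a column
    hi = u[1:, None]                  # d1 .. d(k-1) as a column
    in_step = (x >= lo) & (x < hi)    # row i: d(i) <= x < d(i+1)
    vals = np.arange(1, u.shape[0])   # step values 1 .. k-1
    # each column of in_step has at most one True, so the product picks
    # out that interval's step value, or 0 when x is outside every interval
    return vals @ in_step
```

The result should match pw for sorted, unique udata. It still builds a (k-1) x N mask, so its cost grows with both the number of data points and len(x); whether it beats the np.where version would need measuring.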

np.searchsorted seems to give better performance, but that 1 ms still remains on my computer:

(modified)
def pw(xx, uu):
    """
    Creates the step function
                 | 1,  if d0 <= x < d1
                 | 2,  if d1 <= x < d2
    pw(x,data) = ...
                 | N, if d(N-1) <= x < dN
                 | 0, otherwise
    where di is the ith element in data.
    INPUT:      x   --  interval which the step function is defined over
              data  --  an ordered set of data (without repetitions)
    OUTPUT: pw_func --  an array of size x.shape[0]
    """
    inds = np.searchsorted(uu, xx, side='right')
    vals = np.arange(1,uu.shape[0]+1)
    pw_func = vals[inds[inds != uu.shape[0]]]
    num_mins = np.sum(xx < np.min(uu))
    num_maxs = np.sum(xx > np.max(uu))

    pw_func = np.concatenate((np.zeros(num_mins), pw_func, np.zeros(xx.shape[0]-pw_func.shape[0]-num_mins)))
    return pw_func
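What np.searchsorted returns is the key to this second version. With side='right', each index counts how many entries of uu are <= that x, so for d(i-1) <= x < d(i) it is already the 1-based step number i; 0 means x lies below the data and len(uu) means x is at or beyond the last point. A tiny illustration with made-up values:

```python
import numpy as np

uu = np.array([1, 3, 4, 5, 7])
xx = np.array([0.5, 1.0, 2.0, 7.0, 9.0])

# side='right': number of uu values <= each x (equal values count)
right = np.searchsorted(uu, xx, side='right')
# side='left': number of uu values < each x (equal values do not count)
left = np.searchsorted(uu, xx, side='left')

print(right)  # [0 1 1 5 5]
print(left)   # [0 0 1 4 5]
```

The difference only shows up when an x equals a data point, which is exactly where the step function's half-open intervals [d(i-1), d(i)) matter.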

This answer using np.piecewise seems pretty close, but it works on scalar x0 and x1. How would I do it on arrays? And would it be more efficient?
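np.piecewise does accept an array x, as long as each condition is a boolean array with the same shape as x. A sketch (pw_piecewise is a hypothetical name) with one condition per interval and a constant step value for each; np.piecewise leaves points that match no condition at 0. Because it builds one full-length mask per interval, it does roughly the same k*N work as pw, so it is unlikely to beat the searchsorted approaches.

```python
import numpy as np

def pw_piecewise(x, udata):
    """Sketch: the step function via np.piecewise on an array x.

    One boolean condition per interval [d(i-1), d(i)); np.piecewise fills
    points matching no condition with 0. Output has x's dtype (float here).
    """
    u = np.asarray(udata)
    condlist = [(x >= lo) & (x < hi) for lo, hi in zip(u[:-1], u[1:])]
    funclist = list(range(1, len(condlist) + 1))  # constant step values 1..k-1
    return np.piecewise(x, condlist, funclist)
```

Note that the result is a float array when x is float, unlike the integer output of the other versions.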

Understandably, x may be pretty big, but I'm trying to put it through a stress test.

I'm still learning, so any hints or tricks that can help me out would be great.

EDIT

There seems to be a mistake in the second function, since the array it returns doesn't match the one from the first function (which I'm confident works):

N1 = pw1(x,udata.reshape(udata.shape[0],1)).shape[0]
N2 = np.sum(pw1(x,udata.reshape(udata.shape[0],1)) == pw2(x,udata))
print(N1 - N2)

yields

15000

data points that are not the same. So it seems that I don't know how to use 'searchsorted'.
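Besides counting the mismatches, it can help to see where they are, for example whether they sit at the ends of x or next to the data points. A small sketch; the helper mismatches is hypothetical and the arrays are made up:

```python
import numpy as np

def mismatches(a, b):
    """Hypothetical debugging helper: indices where two results differ."""
    a = np.asarray(a)
    b = np.asarray(b)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    return np.flatnonzero(a != b)

idx = mismatches([0, 1, 1, 2, 0], [0, 1, 2, 2, 1])
print(idx)  # [2 4]
```

Indexing x with the result (x[idx]) then shows which input values produce the disagreement.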

EDIT 2

Actually, I fixed it:

pw_func = vals[inds[inds != uu.shape[0]]]

was changed to

pw_func = vals[inds[inds[(inds != uu.shape[0])*(inds != 0)]-1]]

so at least the resulting arrays match. But the question remains whether there's a more efficient way of doing this.

EDIT 3

Thanks, Tin Lai, for pointing out the mistake. This one should work:

pw_func = vals[inds[(inds != uu.shape[0])*(inds != 0)]-1]

Maybe a more readable way of presenting it would be:

non_endpts = (inds != uu.shape[0])*(inds != 0) # only consider the points in between the min/max data values
shift_inds = inds[non_endpts]-1       # searchsorted side='right' includes the left end point and not right end point so a shift is needed
pw_func = vals[shift_inds]

I think I got lost in all those brackets! I guess that's why readability matters.
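Putting EDIT 3's line into the rest of the second function gives the complete version below. This is a sketch assembled from the post (pw2_fixed is a made-up name); like the original, it assumes x is sorted in ascending order, because the zero padding is placed at the start and the end by position.

```python
import numpy as np

def pw2_fixed(xx, uu):
    """pw2 with the EDIT 3 indexing (hypothetical name); xx must be sorted."""
    uu = np.asarray(uu)
    inds = np.searchsorted(uu, xx, side='right')
    vals = np.arange(1, uu.shape[0] + 1)
    # only consider the points strictly between the min/max data values
    non_endpts = (inds != uu.shape[0]) & (inds != 0)
    # side='right' counts data points <= x, so shift by one to index vals
    pw_func = vals[inds[non_endpts] - 1]
    # points below min(uu) come first in a sorted xx, the rest come last
    num_mins = np.sum(xx < np.min(uu))
    num_rest = xx.shape[0] - pw_func.shape[0] - num_mins
    return np.concatenate((np.zeros(num_mins), pw_func, np.zeros(num_rest)))
```

Since inds[non_endpts] - 1 indexes vals = [1, ..., k], vals[inds[non_endpts] - 1] is just inds[non_endpts], which is the observation pw4 in the answer below builds on.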

ANSWER

A very abstract yet interesting problem! Thanks for entertaining me, I had fun :)

P.S. I'm not sure about your pw2; I wasn't able to get it to output the same as pw1.

For reference, the original pw functions:

def pw1(x, udata):
    # expects udata as a (k, 1) column, e.g. udata.reshape(udata.shape[0], 1)
    vals = np.arange(1,udata.shape[0]+1).reshape(udata.shape[0],1)
    pw_func = np.sum(np.where(np.greater_equal(x,udata)*np.less(x,np.roll(udata,-1)),vals,0),axis=0)
    return pw_func

def pw2(xx, uu):
    inds = np.searchsorted(uu, xx, side='right')
    vals = np.arange(1,uu.shape[0]+1)
    pw_func = vals[inds[inds[(inds != uu.shape[0])*(inds != 0)]-1]]
    num_mins = np.sum(xx < np.min(uu))
    num_maxs = np.sum(xx > np.max(uu))

    pw_func = np.concatenate((np.zeros(num_mins), pw_func, np.zeros(xx.shape[0]-pw_func.shape[0]-num_mins)))
    return pw_func

My first attempt used a lot of NumPy broadcasting operations:

def pw3(x, udata):
    # indexing with None adds a new axis and .T turns it into a (k, 1) column,
    # so step_bool has shape (k, N) with row i meaning x >= udata[i]
    step_bool = x >= udata[None,:].T

    # True sums as 1, so this counts how many data points (skipping
    # the last value in "data") are <= each x
    step_vals = np.sum(step_bool[:-1], axis=0)

    # the row skipped in the previous step (last index) marks the x values
    # at or beyond the last value in "data"; set those to zero
    step_vals[step_bool[-1]] = 0

    return step_vals
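To see what pw3 computes, here are its intermediate arrays on a tiny made-up input. udata[:, None] is used as an equivalent spelling of udata[None,:].T; both give a (k, 1) column.

```python
import numpy as np

x = np.array([0, 1, 2, 3.5, 7, 8])   # made-up evaluation points
udata = np.array([1, 3, 4, 5, 7])    # sorted, unique data points

step_bool = x >= udata[:, None]      # (5, 6); rows for udata = 1, 3, 4, 5, 7
# as 0/1:
# [[0 1 1 1 1 1]
#  [0 0 0 1 1 1]
#  [0 0 0 0 1 1]
#  [0 0 0 0 1 1]
#  [0 0 0 0 1 1]]
step_vals = np.sum(step_bool[:-1], axis=0)   # [0 1 1 2 4 4]
step_vals[step_bool[-1]] = 0                 # [0 1 1 2 0 0]
```

The column sums are the step numbers for points inside the data range; the last row then zeroes every point at or beyond udata[-1].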

After looking at the searchsorted call in your pw2, I came up with a new approach that uses it, with much higher performance:

def pw4(x, udata):
    inds = np.searchsorted(udata, x, side='right')

    # fix-up: x >= udata[-1] gives inds == len(udata), which equals inds[-1]
    # here (x is assumed sorted); map those points to 0
    if x[-1] > udata[-1]:
        inds[inds == inds[-1]] = 0

    return inds
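pw4's fix-up only runs when x[-1] > udata[-1], which presumes x is sorted and misses the case x[-1] == udata[-1] (there searchsorted already returns len(udata), while pw1 gives 0). A sketch of a variant (the name pw4_unsorted is hypothetical) that zeroes those indices unconditionally, so the order of x does not matter:

```python
import numpy as np

def pw4_unsorted(x, udata):
    """Hypothetical variant of pw4 that does not assume x is sorted."""
    udata = np.asarray(udata)
    # inds[j] = number of udata values <= x[j], i.e. the step number i for
    # d(i-1) <= x[j] < d(i); it equals len(udata) at or beyond the last point
    inds = np.searchsorted(udata, x, side='right')
    # zero every point at or beyond udata[-1], wherever it sits in x
    inds[inds == udata.shape[0]] = 0
    return inds

x = np.array([7.0, 0.5, 3.5, 9.0, 1.0])   # deliberately unsorted, made up
udata = np.array([1, 3, 4, 5, 7])
print(pw4_unsorted(x, udata))              # [0 0 2 0 1]
```

The extra comparison over inds is one more linear pass, so its cost should stay close to pw4's, though that is not measured here.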

Plots with:

import matplotlib.pyplot as plt
plt.plot(pw1(x,udata.reshape(udata.shape[0],1)), label='pw1')
plt.plot(pw2(x,udata), label='pw2')
plt.plot(pw3(x,udata), label='pw3')
plt.plot(pw4(x,udata), label='pw4')
plt.legend()
plt.show()

With data = [1,3,4,5,5,7]:

[Plot: pw1, pw2, pw3 and pw4 for data = [1,3,4,5,5,7]]

With data = [1,3,4,5,5,7,11]:

[Plot: pw1, pw2, pw3 and pw4 for data = [1,3,4,5,5,7,11]]

pw1, pw3 and pw4 are all identical:

print(np.all(pw1(x,udata.reshape(udata.shape[0],1)) == pw3(x,udata)))
>>> True
print(np.all(pw1(x,udata.reshape(udata.shape[0],1)) == pw4(x,udata)))
>>> True

Performance (timeit's repeat() did 3 runs by default in the Python version used here; the default is 5 since Python 3.7. Each figure is the total time in seconds for number=1000 calls, not an average):

import timeit
print(timeit.Timer('pw1(x,udata.reshape(udata.shape[0],1))', "from __main__ import pw1, x, udata").repeat(number=1000))
>>> [3.1938983199979702, 1.6096494779994828, 1.962694135003403]
print(timeit.Timer('pw2(x,udata)', "from __main__ import pw2, x, udata").repeat(number=1000))
>>> [0.6884554479984217, 0.6075002400029916, 0.7799002879983163]
print(timeit.Timer('pw3(x,udata)', "from __main__ import pw3, x, udata").repeat(number=1000))
>>> [0.7369808239964186, 0.7557657590004965, 0.8088172269999632]
print(timeit.Timer('pw4(x,udata)', "from __main__ import pw4, x, udata").repeat(number=1000))
>>> [0.20514375300263055, 0.20203858999957447, 0.19906871100101853]
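The totals above cover 1000 calls each, so per call that is roughly 1.6 to 3.2 ms for pw1 and about 0.2 ms for pw4. A sketch of a harness that reports per-call milliseconds directly and, as the timeit documentation suggests, keeps the lowest of the repeated runs rather than an average. The helper name per_call_ms and the number/repeat values are arbitrary choices, not from the original post; pw1 and pw4 are copied from above.

```python
import timeit
import numpy as np

def pw1(x, udata):
    # udata must be a (k, 1) column for the comparisons to broadcast
    vals = np.arange(1, udata.shape[0] + 1).reshape(udata.shape[0], 1)
    mask = np.greater_equal(x, udata) * np.less(x, np.roll(udata, -1))
    return np.sum(np.where(mask, vals, 0), axis=0)

def pw4(x, udata):
    inds = np.searchsorted(udata, x, side='right')
    if x[-1] > udata[-1]:
        inds[inds == inds[-1]] = 0
    return inds

def per_call_ms(stmt, number=20, repeat=5):
    """Best-of-`repeat` time per call, in milliseconds (hypothetical helper).

    timeit.repeat returns the total seconds for `number` calls in each run,
    so divide by `number`; the minimum is the least noisy estimate.
    """
    totals = timeit.repeat(stmt, number=number, repeat=repeat)
    return min(totals) / number * 1e3

x = np.linspace(0, 10, 50000)
udata = np.unique([1, 3, 4, 5, 5, 7])
col = udata.reshape(-1, 1)

results = {
    "pw1": per_call_ms(lambda: pw1(x, col)),
    "pw4": per_call_ms(lambda: pw4(x, udata)),
}
for name, ms in results.items():
    print(f"{name}: {ms:.3f} ms per call")
```

Passing a callable to timeit avoids the "from __main__ import ..." setup string, so the harness also works when imported rather than run as a script.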

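Statement: The technical posts on this site are licensed under CC BY-SA 4.0. If you reproduce them, please cite this site's URL or the original address. For any questions, contact: yoyou2525@163.com.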

 
© 2020-2024 STACKOOM.COM