繁体   English   中英

仅使用 NumPy 优化步进函数的代码

[英]Optimize code for step function using only NumPy

我正在尝试仅使用 NumPy 函数(或者可能是列表推导式)优化以下代码中的函数“pw”。

from time import time
import numpy as np

def pw(x, udata):
    """
    Creates the step function
                 | 1,  if d0 <= x < d1
                 | 2,  if d1 <= x < d2
    pw(x,data) = ...
                 | N, if d(N-1) <= x < dN
                 | 0, otherwise
    where di is the ith element in data.
    INPUT:      x   --  interval which the step function is defined over
              data  --  an ordered set of data (without repetitions)
    OUTPUT: pw_func --  an array of size x.shape[0]
    """
    vals = np.arange(1,udata.shape[0]+1).reshape(udata.shape[0],1)
    pw_func = np.sum(np.where(np.greater_equal(x,udata)*np.less(x,np.roll(udata,-1)),vals,0),axis=0)
    return pw_func


N = 50000
x = np.linspace(0,10,N)
data = [1,3,4,5,5,7]
udata = np.unique(data)

ti = time()
pw(x,udata)
tf = time()
print(tf - ti)

import cProfile
cProfile.run('pw(x,udata)')

cProfile.run 告诉我大部分开销来自 np.where(大约 1 毫秒),但如果可能的话,我想创建更快的代码。 似乎按行和按列执行操作会有所不同,除非我弄错了,但我想我已经考虑到了。 我知道有时列表理解会更快,但我想不出比我正在做的更快的方法。

Searchsorted 似乎产生更好的性能,但 1 ms 仍然保留在我的计算机上:

(modified)
def pw(xx, uu):
    """
    Creates the step function
                 | 1,  if d0 <= x < d1
                 | 2,  if d1 <= x < d2
    pw(x,data) = ...
                 | N, if d(N-1) <= x < dN
                 | 0, otherwise
    where di is the ith element in data.
    INPUT:      x   --  interval which the step function is defined over
              data  --  an ordered set of data (without repetitions)
    OUTPUT: pw_func --  an array of size x.shape[0]
    """
    inds = np.searchsorted(uu, xx, side='right')
    vals = np.arange(1,uu.shape[0]+1)
    pw_func = vals[inds[inds != uu.shape[0]]]
    num_mins = np.sum(xx < np.min(uu))
    num_maxs = np.sum(xx > np.max(uu))

    pw_func = np.concatenate((np.zeros(num_mins), pw_func, np.zeros(xx.shape[0]-pw_func.shape[0]-num_mins)))
    return pw_func

这个使用分段的答案似乎非常接近,但这是在标量 x0 和 x1 上。 我将如何在阵列上做到这一点? 它会更有效率吗?

可以理解,x 可能相当大,但我正在尝试对其进行压力测试。

我仍在学习,所以一些可以帮助我的提示或技巧会很棒。

编辑

第二个函数似乎有错误,因为第二个函数的结果数组与第一个不匹配(我相信它可以工作):

N1 = pw1(x,udata.reshape(udata.shape[0],1)).shape[0]
N2 = np.sum(pw1(x,udata.reshape(udata.shape[0],1)) == pw2(x,udata))
print(N1 - N2)

产量

15000

相同的数据点。 所以似乎我不知道如何使用'searchsorted'。

编辑 2

实际上我修复了它:

pw_func = vals[inds[inds != uu.shape[0]]]

改为

pw_func = vals[inds[inds[(inds != uu.shape[0])*(inds != 0)]-1]]

所以至少结果数组匹配。 但问题仍然是是否有更有效的方法来做到这一点。

编辑 3

感谢天丽指出错误。 这个应该工作

pw_func = vals[inds[(inds != uu.shape[0])*(inds != 0)]-1]

也许一种更易读的呈现方式是

non_endpts = (inds != uu.shape[0])*(inds != 0) # only consider the points in between the min/max data values
shift_inds = inds[non_endpts]-1       # searchsorted side='right' includes the left end point and not right end point so a shift is needed
pw_func = vals[shift_inds]

我想我迷失在所有这些括号中! 我想这就是可读性的重要性。

一个非常抽象但有趣的问题! 谢谢你招待我,我玩得很开心:)

ps 我不确定你的pw2我无法得到与pw1相同的输出。

供参考原始pw

def pw1(x, udata):
    vals = np.arange(1,udata.shape[0]+1).reshape(udata.shape[0],1)
    pw_func = np.sum(np.where(np.greater_equal(x,udata)*np.less(x,np.roll(udata,-1)),vals,0),axis=0)
    return pw_func

def pw2(xx, uu):
    inds = np.searchsorted(uu, xx, side='right')
    vals = np.arange(1,uu.shape[0]+1)
    pw_func = vals[inds[inds[(inds != uu.shape[0])*(inds != 0)]-1]]
    num_mins = np.sum(xx < np.min(uu))
    num_maxs = np.sum(xx > np.max(uu))

    pw_func = np.concatenate((np.zeros(num_mins), pw_func, np.zeros(xx.shape[0]-pw_func.shape[0]-num_mins)))
    return pw_func

我的第一次尝试是利用numpy的大量广播操作:

def pw3(x, udata):
    # the None slice is to create new axis
    step_bool = x >= udata[None,:].T

    # we exploit the fact that bools are integer value of 1s
    # skipping the last value in "data"
    step_vals = np.sum(step_bool[:-1], axis=0)

    # for the step_bool that we skipped from previous step (last index)
    # we set it to zerp so that we can negate the step_vals once we reached
    # the last value in "data"
    step_vals[step_bool[-1]] = 0

    return step_vals

从您的searchsorted中查看搜索pw2我有了一种新方法,可以以更高的性能利用它:

def pw4(x, udata):
    inds = np.searchsorted(udata, x, side='right')

    # fix-ups the last data if x is already out of range of data[-1]
    if x[-1] > udata[-1]:
        inds[inds == inds[-1]] = 0

    return inds

情节与:

plt.plot(pw1(x,udata.reshape(udata.shape[0],1)), label='pw1')
plt.plot(pw2(x,udata), label='pw2')
plt.plot(pw3(x,udata), label='pw3')
plt.plot(pw4(x,udata), label='pw4')

data = [1,3,4,5,5,7]

在此处输入图片说明

data = [1,3,4,5,5,7,11]

在此处输入图片说明

pw1 , pw3 , pw4都是一样的

print(np.all(pw1(x,udata.reshape(udata.shape[0],1)) == pw3(x,udata)))
>>> True
print(np.all(pw1(x,udata.reshape(udata.shape[0],1)) == pw4(x,udata)))
>>> True

性能:timeit默认运行3次,平均number=N次)

print(timeit.Timer('pw1(x,udata.reshape(udata.shape[0],1))', "from __main__ import pw1, x, udata").repeat(number=1000))
>>> [3.1938983199979702, 1.6096494779994828, 1.962694135003403]
print(timeit.Timer('pw2(x,udata)', "from __main__ import pw2, x, udata").repeat(number=1000))
>>> [0.6884554479984217, 0.6075002400029916, 0.7799002879983163]
print(timeit.Timer('pw3(x,udata)', "from __main__ import pw3, x, udata").repeat(number=1000))
>>> [0.7369808239964186, 0.7557657590004965, 0.8088172269999632]
print(timeit.Timer('pw4(x,udata)', "from __main__ import pw4, x, udata").repeat(number=1000))
>>> [0.20514375300263055, 0.20203858999957447, 0.19906871100101853]

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM