簡體   English   中英

僅使用 NumPy 優化步進函數的代碼

[英]Optimize code for step function using only NumPy

我正在嘗試僅使用 NumPy 函數(或者可能是列表推導式)優化以下代碼中的函數“pw”。

from time import time
import numpy as np

def pw(x, udata):
    """
    Creates the step function
                 | 1,  if d0 <= x < d1
                 | 2,  if d1 <= x < d2
    pw(x,data) = ...
                 | N, if d(N-1) <= x < dN
                 | 0, otherwise
    where di is the ith element in data.
    INPUT:      x   --  interval which the step function is defined over
              data  --  an ordered set of data (without repetitions)
    OUTPUT: pw_func --  an array of size x.shape[0]
    """
    vals = np.arange(1,udata.shape[0]+1).reshape(udata.shape[0],1)
    pw_func = np.sum(np.where(np.greater_equal(x,udata)*np.less(x,np.roll(udata,-1)),vals,0),axis=0)
    return pw_func


N = 50000
x = np.linspace(0,10,N)
data = [1,3,4,5,5,7]
udata = np.unique(data)

ti = time()
pw(x,udata)
tf = time()
print(tf - ti)

import cProfile
cProfile.run('pw(x,udata)')

cProfile.run 告訴我大部分開銷來自 np.where(大約 1 毫秒),但如果可能的話,我想創建更快的代碼。 似乎按行和按列執行操作會有所不同,除非我弄錯了,但我想我已經考慮到了。 我知道有時列表理解會更快,但我想不出比我正在做的更快的方法。

Searchsorted 似乎產生更好的性能,但 1 ms 仍然保留在我的計算機上:

(modified)
def pw(xx, uu):
    """
    Creates the step function
                 | 1,  if d0 <= x < d1
                 | 2,  if d1 <= x < d2
    pw(x,data) = ...
                 | N, if d(N-1) <= x < dN
                 | 0, otherwise
    where di is the ith element in data.
    INPUT:      x   --  interval which the step function is defined over
              data  --  an ordered set of data (without repetitions)
    OUTPUT: pw_func --  an array of size x.shape[0]
    """
    inds = np.searchsorted(uu, xx, side='right')
    vals = np.arange(1,uu.shape[0]+1)
    pw_func = vals[inds[inds != uu.shape[0]]]
    num_mins = np.sum(xx < np.min(uu))
    num_maxs = np.sum(xx > np.max(uu))

    pw_func = np.concatenate((np.zeros(num_mins), pw_func, np.zeros(xx.shape[0]-pw_func.shape[0]-num_mins)))
    return pw_func

這個使用分段的答案似乎非常接近,但這是在標量 x0 和 x1 上。 我將如何在陣列上做到這一點? 它會更有效率嗎?

可以理解,x 可能相當大,但我正在嘗試對其進行壓力測試。

我仍在學習,所以一些可以幫助我的提示或技巧會很棒。

編輯

第二個函數似乎有錯誤,因為第二個函數的結果數組與第一個不匹配(我相信它可以工作):

N1 = pw1(x,udata.reshape(udata.shape[0],1)).shape[0]
N2 = np.sum(pw1(x,udata.reshape(udata.shape[0],1)) == pw2(x,udata))
print(N1 - N2)

產量

15000

相同的數據點。 所以似乎我不知道如何使用'searchsorted'。

編輯 2

實際上我修復了它:

pw_func = vals[inds[inds != uu.shape[0]]]

改為

pw_func = vals[inds[inds[(inds != uu.shape[0])*(inds != 0)]-1]]

所以至少結果數組匹配。 但問題仍然是是否有更有效的方法來做到這一點。

編輯 3

感謝天麗指出錯誤。 這個應該工作

pw_func = vals[inds[(inds != uu.shape[0])*(inds != 0)]-1]

也許一種更易讀的呈現方式是

non_endpts = (inds != uu.shape[0])*(inds != 0) # only consider the points in between the min/max data values
shift_inds = inds[non_endpts]-1       # searchsorted side='right' includes the left end point and not right end point so a shift is needed
pw_func = vals[shift_inds]

我想我迷失在所有這些括號中! 我想這就是可讀性的重要性。

一個非常抽象但有趣的問題! 謝謝你招待我,我玩得很開心:)

ps 我不確定你的pw2我無法得到與pw1相同的輸出。

供參考原始pw

def pw1(x, udata):
    vals = np.arange(1,udata.shape[0]+1).reshape(udata.shape[0],1)
    pw_func = np.sum(np.where(np.greater_equal(x,udata)*np.less(x,np.roll(udata,-1)),vals,0),axis=0)
    return pw_func

def pw2(xx, uu):
    inds = np.searchsorted(uu, xx, side='right')
    vals = np.arange(1,uu.shape[0]+1)
    pw_func = vals[inds[inds[(inds != uu.shape[0])*(inds != 0)]-1]]
    num_mins = np.sum(xx < np.min(uu))
    num_maxs = np.sum(xx > np.max(uu))

    pw_func = np.concatenate((np.zeros(num_mins), pw_func, np.zeros(xx.shape[0]-pw_func.shape[0]-num_mins)))
    return pw_func

我的第一次嘗試是利用numpy的大量廣播操作:

def pw3(x, udata):
    # the None slice is to create new axis
    step_bool = x >= udata[None,:].T

    # we exploit the fact that bools are integer value of 1s
    # skipping the last value in "data"
    step_vals = np.sum(step_bool[:-1], axis=0)

    # for the step_bool that we skipped from previous step (last index)
    # we set it to zerp so that we can negate the step_vals once we reached
    # the last value in "data"
    step_vals[step_bool[-1]] = 0

    return step_vals

從您的searchsorted中查看搜索pw2我有了一種新方法,可以以更高的性能利用它:

def pw4(x, udata):
    inds = np.searchsorted(udata, x, side='right')

    # fix-ups the last data if x is already out of range of data[-1]
    if x[-1] > udata[-1]:
        inds[inds == inds[-1]] = 0

    return inds

情節與:

plt.plot(pw1(x,udata.reshape(udata.shape[0],1)), label='pw1')
plt.plot(pw2(x,udata), label='pw2')
plt.plot(pw3(x,udata), label='pw3')
plt.plot(pw4(x,udata), label='pw4')

data = [1,3,4,5,5,7]

在此處輸入圖片說明

data = [1,3,4,5,5,7,11]

在此處輸入圖片說明

pw1 , pw3 , pw4都是一樣的

print(np.all(pw1(x,udata.reshape(udata.shape[0],1)) == pw3(x,udata)))
>>> True
print(np.all(pw1(x,udata.reshape(udata.shape[0],1)) == pw4(x,udata)))
>>> True

性能:timeit默認運行3次,平均number=N次)

print(timeit.Timer('pw1(x,udata.reshape(udata.shape[0],1))', "from __main__ import pw1, x, udata").repeat(number=1000))
>>> [3.1938983199979702, 1.6096494779994828, 1.962694135003403]
print(timeit.Timer('pw2(x,udata)', "from __main__ import pw2, x, udata").repeat(number=1000))
>>> [0.6884554479984217, 0.6075002400029916, 0.7799002879983163]
print(timeit.Timer('pw3(x,udata)', "from __main__ import pw3, x, udata").repeat(number=1000))
>>> [0.7369808239964186, 0.7557657590004965, 0.8088172269999632]
print(timeit.Timer('pw4(x,udata)', "from __main__ import pw4, x, udata").repeat(number=1000))
>>> [0.20514375300263055, 0.20203858999957447, 0.19906871100101853]

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM