
Optimize code for step function using only NumPy

I'm trying to optimize the function 'pw' in the following code using only NumPy functions (or perhaps list comprehensions).

from time import time
import numpy as np

def pw(x, udata):
    """
    Creates the step function
                 | 1,  if d0 <= x < d1
                 | 2,  if d1 <= x < d2
    pw(x,data) = ...
                 | N, if d(N-1) <= x < dN
                 | 0, otherwise
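    where di is the ith element in udata.
    INPUT:      x   --  interval which the step function is defined over
              udata --  an ordered set of data (without repetitions)
    OUTPUT: pw_func --  an array of size x.shape[0]
    """
    # reshape udata into a column vector (n, 1) so it broadcasts against x (N,);
    # without this, pw(x, udata) with a 1-D udata raises a broadcasting error
    ud = np.asarray(udata).reshape(-1, 1)
    vals = np.arange(1, ud.shape[0] + 1).reshape(ud.shape[0], 1)
    pw_func = np.sum(np.where(np.greater_equal(x, ud) * np.less(x, np.roll(ud, -1)), vals, 0), axis=0)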
    return pw_func


N = 50000
x = np.linspace(0,10,N)
data = [1,3,4,5,5,7]
udata = np.unique(data)

ti = time()
pw(x,udata)
tf = time()
print(tf - ti)

import cProfile
cProfile.run('pw(x,udata)')
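To make the broadcasting in pw concrete, here is a small illustration on made-up toy inputs (x_small and d_small are hypothetical, chosen only so the mask is easy to read). Comparing the column of boundaries against the row of x values produces one boolean row per interval, so pw does work proportional to len(udata) * len(x). The last row is always empty because np.roll wraps its upper bound around to d0.

```python
import numpy as np

# hypothetical toy inputs, small enough to print the whole mask
x_small = np.array([0.5, 1.0, 2.0, 3.5, 4.0, 6.0])
d_small = np.array([1, 3, 4])

col = d_small.reshape(-1, 1)                  # shape (3, 1)
upper = np.roll(col, -1)                      # next boundary; the last row wraps to d_small[0]
mask = (x_small >= col) & (x_small < upper)   # shape (3, 6): row i is "d_i <= x < d_(i+1)"
vals = np.arange(1, d_small.shape[0] + 1).reshape(-1, 1)

print(mask.shape)                             # (3, 6)
print(mask.astype(int))
print(np.sum(np.where(mask, vals, 0), axis=0))   # [0 1 1 2 0 0]
```

Only one row can be True in each column, so the sum over axis 0 simply picks out that row's value, or 0 if no row matches.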

cProfile.run tells me that most of the time is spent in np.where (about 1 ms), but I'd like faster code if possible. Performing the operations row-wise versus column-wise seems to make some difference, unless I'm mistaken, but I think I've accounted for that. I know list comprehensions are sometimes faster, but I couldn't find a way to use one that beats what I already have.

np.searchsorted seems to give better performance, but that 1 ms still remains on my computer:

(modified)
def pw(xx, uu):
    """
    Creates the step function
                 | 1,  if d0 <= x < d1
                 | 2,  if d1 <= x < d2
    pw(x,data) = ...
                 | N, if d(N-1) <= x < dN
                 | 0, otherwise
    where di is the ith element in data.
    INPUT:      x   --  interval which the step function is defined over
              data  --  an ordered set of data (without repetitions)
    OUTPUT: pw_func --  an array of size x.shape[0]
    """
    inds = np.searchsorted(uu, xx, side='right')
    vals = np.arange(1,uu.shape[0]+1)
    pw_func = vals[inds[inds != uu.shape[0]]]
    num_mins = np.sum(xx < np.min(uu))
    num_maxs = np.sum(xx > np.max(uu))

    pw_func = np.concatenate((np.zeros(num_mins), pw_func, np.zeros(xx.shape[0]-pw_func.shape[0]-num_mins)))
    return pw_func

This answer using np.piecewise seems pretty close, but it works on scalar x0 and x1. How would I do this with arrays? And would it be more efficient?

Admittedly, x may be quite large, but I'm deliberately putting the function through a stress test.

I am still learning, so any hints or tricks that can help me out would be great.
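On the np.piecewise question: it does accept an array x, because each entry of condlist can be a boolean array with the same shape as x, and a scalar in funclist is treated as a constant value. A minimal sketch, assuming udata is sorted and unique (pw_piecewise is a hypothetical name). It still builds one full-length boolean array per interval, so it is not obviously faster than pw; that would need to be measured.

```python
import numpy as np

def pw_piecewise(x, udata):
    """Step function via np.piecewise (sketch; assumes udata is sorted and unique)."""
    x = np.asarray(x, dtype=float)
    lower, upper = udata[:-1], udata[1:]
    # one boolean array per interval [d_i, d_(i+1)), each the same shape as x
    condlist = [(x >= lo) & (x < hi) for lo, hi in zip(lower, upper)]
    # scalar entries act as constants; with exactly one entry per condition,
    # points that match no condition keep the default value 0
    funclist = list(range(1, len(udata)))
    return np.piecewise(x, condlist, funclist)

x = np.linspace(0, 10, 11)
udata = np.unique([1, 3, 4, 5, 5, 7])
print(pw_piecewise(x, udata))   # [0. 1. 1. 2. 3. 4. 4. 0. 0. 0. 0.]
```

Note that np.piecewise returns an array with the dtype of x, so the step values come back as floats here.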

EDIT

There seems to be a mistake in the second function, since its output doesn't match that of the first one (which I'm confident works). Calling the first function pw1 and the second pw2:

N1 = pw1(x,udata.reshape(udata.shape[0],1)).shape[0]
N2 = np.sum(pw1(x,udata.reshape(udata.shape[0],1)) == pw2(x,udata))
print(N1 - N2)

yields

15000

data points that are not the same. So it seems that I don't know how to use 'searchsorted'.
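A small illustration of what searchsorted(..., side='right') returns, on hypothetical toy values: for each x it gives the number of entries of udata that are <= x. So 0 means x lies left of the first value, len(udata) means x is at or beyond the last one, and anything in between is already the step value.

```python
import numpy as np

# hypothetical toy values
udata = np.array([1, 3, 4, 5, 7])
x = np.array([0.5, 1.0, 2.0, 3.0, 6.9, 7.0, 8.0])

# side='right': index i such that udata[i-1] <= x < udata[i]
inds = np.searchsorted(udata, x, side='right')
print(inds)   # [0 1 1 2 4 5 5]

vals = np.arange(1, udata.shape[0] + 1)
inner = (inds != 0) & (inds != udata.shape[0])
# vals[k - 1] == k, so looking up vals just returns inds for the inner points
print(np.array_equal(vals[inds[inner] - 1], inds[inner]))   # True
```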

EDIT 2

Actually I fixed it:

pw_func = vals[inds[inds != uu.shape[0]]]

was changed to

pw_func = vals[inds[inds[(inds != uu.shape[0])*(inds != 0)]-1]]

so at least the resulting arrays match. But the question remains whether there is a more efficient way to do this.

EDIT 3

Thanks to Tin Lai for pointing out the mistake. This one should work:

pw_func = vals[inds[(inds != uu.shape[0])*(inds != 0)]-1]

A more readable way of presenting it might be:

non_endpts = (inds != uu.shape[0])*(inds != 0) # only consider the points in between the min/max data values
shift_inds = inds[non_endpts]-1       # searchsorted side='right' includes the left end point and not right end point so a shift is needed
pw_func = vals[shift_inds]

I think I got lost in all those brackets! I guess that's the importance of readability.
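Putting the EDIT 3 fix back into the full second function gives the sketch below (pw2_fixed is a hypothetical name, and & replaces * for combining the masks, which gives the same result for booleans). It still assumes xx is sorted ascending, because the zero padding is placed at the two ends of the output. The check inlines the question's first function so the snippet runs on its own.

```python
import numpy as np

def pw2_fixed(xx, uu):
    """Second function with the EDIT 3 fix; assumes xx is sorted ascending."""
    inds = np.searchsorted(uu, xx, side='right')
    vals = np.arange(1, uu.shape[0] + 1)
    # only consider the points in between the min/max data values
    non_endpts = (inds != uu.shape[0]) & (inds != 0)
    # side='right' includes the left end point but not the right one, so shift by 1
    pw_func = vals[inds[non_endpts] - 1]
    # zeros go on the left (x < min) and right (x >= max); correct only for sorted xx
    num_mins = np.sum(xx < np.min(uu))
    num_tail = xx.shape[0] - pw_func.shape[0] - num_mins
    return np.concatenate((np.zeros(num_mins), pw_func, np.zeros(num_tail)))

x = np.linspace(0, 10, 50000)
udata = np.unique([1, 3, 4, 5, 5, 7])

# the question's first function (pw1), inlined so this check runs on its own
col = udata.reshape(-1, 1)
col_vals = np.arange(1, udata.shape[0] + 1).reshape(-1, 1)
ref = np.sum(np.where((x >= col) & (x < np.roll(col, -1)), col_vals, 0), axis=0)

print(np.array_equal(pw2_fixed(x, udata), ref))   # True
```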

A very abstract yet interesting problem! Thanks for entertaining me, I had fun :)

P.S. I'm not sure about your pw2; I wasn't able to get it to output the same as pw1.

For reference, the original pw functions:

def pw1(x, udata):
    vals = np.arange(1,udata.shape[0]+1).reshape(udata.shape[0],1)
    pw_func = np.sum(np.where(np.greater_equal(x,udata)*np.less(x,np.roll(udata,-1)),vals,0),axis=0)
    return pw_func

def pw2(xx, uu):
    inds = np.searchsorted(uu, xx, side='right')
    vals = np.arange(1,uu.shape[0]+1)
    pw_func = vals[inds[inds[(inds != uu.shape[0])*(inds != 0)]-1]]
    num_mins = np.sum(xx < np.min(uu))
    num_maxs = np.sum(xx > np.max(uu))

    pw_func = np.concatenate((np.zeros(num_mins), pw_func, np.zeros(xx.shape[0]-pw_func.shape[0]-num_mins)))
    return pw_func

My first attempt relied heavily on NumPy broadcasting:

def pw3(x, udata):
    # indexing with None adds a new axis; .T turns udata into a column vector
    step_bool = x >= udata[None,:].T

    # True counts as 1 when summed, so each column sum is the number of
    # data values <= x; the last value in "data" is skipped here
    step_vals = np.sum(step_bool[:-1], axis=0)

    # where the skipped last value is also <= x (i.e. x has reached the
    # last value in "data"), set the result to zero
    step_vals[step_bool[-1]] = 0

    return step_vals

After looking at the searchsorted call in your pw2, I came up with a new approach based on it that performs much better:

def pw4(x, udata):
    inds = np.searchsorted(udata, x, side='right')

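    # points at or beyond the last data value (inds == len(udata)) lie outside
    # every interval, so set them to 0; this also covers x[-1] == udata[-1]
    # and does not require x to be sorted
    inds[inds == udata.shape[0]] = 0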

    return inds
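As a variation not taken from the answer above, np.digitize gives the same indices as searchsorted(..., side='right') when the bins are increasing and right=False, so the same idea can be written with it. This sketch (pw_digitize is a hypothetical name) maps every index equal to len(udata) to 0, which handles a last data value that coincides with x[-1] and does not depend on x being sorted.

```python
import numpy as np

def pw_digitize(x, udata):
    """Step function via np.digitize (sketch; udata must be increasing)."""
    # with increasing bins and right=False, digitize returns i such that
    # udata[i-1] <= x < udata[i]; 0 below udata[0], len(udata) at or above udata[-1]
    inds = np.digitize(x, udata, right=False)
    # points at or beyond the last data value belong to no interval
    inds[inds == udata.shape[0]] = 0
    return inds

# here the last data value coincides with the last grid point, x[-1] == 10
x = np.linspace(0, 10, 11)
udata = np.unique([1, 3, 4, 5, 5, 7, 10])
print(pw_digitize(x, udata))   # [0 1 1 2 3 4 4 5 5 5 0]
```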

Plotted with:

import matplotlib.pyplot as plt

plt.plot(pw1(x,udata.reshape(udata.shape[0],1)), label='pw1')
plt.plot(pw2(x,udata), label='pw2')
plt.plot(pw3(x,udata), label='pw3')
plt.plot(pw4(x,udata), label='pw4')
plt.legend()
plt.show()

With data = [1,3,4,5,5,7]:

[Image: plot of the outputs of pw1, pw2, pw3 and pw4]

With data = [1,3,4,5,5,7,11]:

[Image: plot of the outputs of pw1, pw2, pw3 and pw4]

pw1, pw3 and pw4 are all identical:

print(np.all(pw1(x,udata.reshape(udata.shape[0],1)) == pw3(x,udata)))
>>> True
print(np.all(pw1(x,udata.reshape(udata.shape[0],1)) == pw4(x,udata)))
>>> True

Performance: each number below is the total time in seconds for number=1000 calls; Timer.repeat runs the measurement 3 times by default (5 since Python 3.7).

import timeit

print(timeit.Timer('pw1(x,udata.reshape(udata.shape[0],1))', "from __main__ import pw1, x, udata").repeat(number=1000))
>>> [3.1938983199979702, 1.6096494779994828, 1.962694135003403]
print(timeit.Timer('pw2(x,udata)', "from __main__ import pw2, x, udata").repeat(number=1000))
>>> [0.6884554479984217, 0.6075002400029916, 0.7799002879983163]
print(timeit.Timer('pw3(x,udata)', "from __main__ import pw3, x, udata").repeat(number=1000))
>>> [0.7369808239964186, 0.7557657590004965, 0.8088172269999632]
print(timeit.Timer('pw4(x,udata)', "from __main__ import pw4, x, udata").repeat(number=1000))
>>> [0.20514375300263055, 0.20203858999957447, 0.19906871100101853]
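The timeit documentation suggests that the minimum of the repeat results is usually more informative than the mean, since higher values mostly reflect interference from other processes. A sketch of that pattern for a searchsorted-based function (pw_searchsorted is a hypothetical stand-in for pw4, and the repeat and number values are arbitrary); no timing output is shown because it depends on the machine.

```python
import timeit
import numpy as np

x = np.linspace(0, 10, 50000)
udata = np.unique([1, 3, 4, 5, 5, 7])

def pw_searchsorted(x, udata):
    # searchsorted-based step function in the style of pw4 (hypothetical name)
    inds = np.searchsorted(udata, x, side='right')
    inds[inds == udata.shape[0]] = 0
    return inds

number = 100
timer = timeit.Timer(lambda: pw_searchsorted(x, udata))
# repeat() returns one total (in seconds) per run, each covering `number` calls
totals = timer.repeat(repeat=5, number=number)
best_per_call = min(totals) / number
print(f"best of {len(totals)}: {best_per_call * 1e6:.1f} us per call")
```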
