
Optimize Numba and Numpy function

I'm trying to make this piece of code run faster, but I can't find any more tricks to speed it up.

Each call takes about 3 microseconds. The problem is that I call this function a couple of million times, so the whole process ends up taking too long. I have the same implementation in Java (with just basic for loops), and there the computations are essentially instantaneous even for large training data (this is for an ANN).

Is there a way to speed this up?

I'm running Python 2.7, numba 0.43.1 and numpy 1.16.3 on Windows 10

import numba as nb
import numpy as np

x = True
expected = 0.5
eligibility = np.array([0.1,0.1,0.1])
positive_weight = np.array([0.2,0.2,0.2])
total_sq_grad_positive = np.array([0.1,0.1,0.1])
learning_rate = 1

@nb.njit(fastmath=True, cache=True)
def update_weight_from_post_post_jit(x, expected,eligibility,positive_weight,total_sq_grad_positive,learning_rate):
    if x:
        g = np.multiply(eligibility,(1-expected))
    else:
        g = np.negative(np.multiply(eligibility,expected))
    gg = np.multiply(g,g)
    total_sq_grad_positive = np.add(total_sq_grad_positive,gg)
    #total_sq_grad_positive = np.where(divide_by_zero,total_sq_grad_positive, tsgp_temp)

    temp = np.multiply(learning_rate, g)
    temp2 = np.sqrt(total_sq_grad_positive)
    #temp2 = np.where(temp2 == 0,1,temp2 )
    temp2[temp2 == 0] = 1
    temp = np.divide(temp,temp2)
    positive_weight = np.add(positive_weight, temp)
    return [positive_weight, total_sq_grad_positive]
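Written out per element, the function computes the following. The symbols are my own shorthand, not from the code: e_i = eligibility[i], G_i = total_sq_grad_positive[i], w_i = positive_weight[i], η = learning_rate. It looks like an AdaGrad-style update, and each element depends only on its own index, so a plain scalar loop needs no temporary arrays.

```latex
g_i = \begin{cases}
  e_i\,(1 - \mathrm{expected}) & \text{if } x\\
  -\,e_i\,\mathrm{expected} & \text{otherwise}
\end{cases}
\qquad
G_i \leftarrow G_i + g_i^2
\qquad
d_i = \begin{cases}
  \sqrt{G_i} & \text{if } \sqrt{G_i} \neq 0\\
  1 & \text{otherwise}
\end{cases}
\qquad
w_i \leftarrow w_i + \frac{\eta\, g_i}{d_i}
```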

Edit: It seems that @max9111 is right: the unnecessary temporary arrays are where the overhead comes from.

With the current semantics of your function, there seem to be two temporary arrays that cannot be avoided: the return values [positive_weight, total_sq_grad_positive]. However, it struck me that you may be planning to use this function to update those two input arrays. If so, doing everything in place gives the largest speedup, like this:

import numba as nb
import numpy as np

x = True
expected = 0.5
eligibility = np.array([0.1,0.1,0.1])
positive_weight = np.array([0.2,0.2,0.2])
total_sq_grad_positive = np.array([0.1,0.1,0.1])
learning_rate = 1

@nb.njit(fastmath= True, cache = True)
def update_weight_from_post_post_jit(x, expected,eligibility,positive_weight,total_sq_grad_positive,learning_rate):
    # Scalar loop: no temporary arrays are created; positive_weight and
    # total_sq_grad_positive are updated in place and nothing is returned.
    for i in range(eligibility.shape[0]):
        if x:
            g = eligibility[i] * (1-expected)
        else:
            g = -(eligibility[i] * expected)
        gg = g * g
        total_sq_grad_positive[i] = total_sq_grad_positive[i] + gg

        temp = learning_rate * g
        temp2 = np.sqrt(total_sq_grad_positive[i])
        if temp2 == 0: temp2 = 1  # avoid dividing by zero
        temp = temp / temp2
        positive_weight[i] = positive_weight[i] + temp

# Looping inside jitted code avoids a Python-level call for every update.
@nb.jit
def test(n, *args):
    for i in range(n): update_weight_from_post_post_jit(*args)
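Because this version returns None and writes into positive_weight and total_sq_grad_positive, the caller reads the results back from those arrays. A minimal usage sketch, checked against the question's vectorized formulas. reference_update is a hypothetical helper written only for the comparison, and the guarded import (falling back to plain Python when Numba is missing) is my assumption, not part of the answer:

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: without Numba installed, fall back to plain Python (slow but correct).
    def njit(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]
        return lambda f: f


@njit(fastmath=True)
def update_weight_from_post_post_jit(x, expected, eligibility, positive_weight,
                                     total_sq_grad_positive, learning_rate):
    # Condensed copy of the in-place loop: mutates both arrays, returns None.
    for i in range(eligibility.shape[0]):
        if x:
            g = eligibility[i] * (1 - expected)
        else:
            g = -(eligibility[i] * expected)
        total_sq_grad_positive[i] += g * g
        denom = np.sqrt(total_sq_grad_positive[i])
        if denom == 0:
            denom = 1.0
        positive_weight[i] += learning_rate * g / denom


def reference_update(x, expected, eligibility, positive_weight,
                     total_sq_grad_positive, learning_rate):
    # Hypothetical helper: the question's vectorized NumPy math, returning new arrays.
    g = eligibility * (1 - expected) if x else -(eligibility * expected)
    total = total_sq_grad_positive + g * g
    denom = np.sqrt(total)
    denom[denom == 0] = 1
    return positive_weight + learning_rate * g / denom, total


eligibility = np.array([0.1, 0.1, 0.1])
w = np.array([0.2, 0.2, 0.2])
t = np.array([0.1, 0.1, 0.1])

# Compute the reference first, before w and t are modified.
ref_w, ref_t = reference_update(True, 0.5, eligibility, w, t, 1)

# The in-place version returns nothing: the results end up in w and t.
update_weight_from_post_post_jit(True, 0.5, eligibility, w, t, 1)
print(np.allclose(w, ref_w), np.allclose(t, ref_t))  # True True
```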

If updating the input arrays is not what you want, you can begin the function with

positive_weight = positive_weight.copy()
total_sq_grad_positive = total_sq_grad_positive.copy()

and return them as in your original code. This is not nearly as fast as the in-place version, but still faster than the original.
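A sketch of that copy-and-return variant, assuming you want the question's original semantics (inputs left untouched, results returned). The name update_weight_copy is hypothetical, returning a tuple instead of a list is my choice, and the guarded import is only there so the sketch runs without Numba:

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: without Numba installed, fall back to plain Python (slow but correct).
    def njit(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]
        return lambda f: f


@njit(fastmath=True)
def update_weight_copy(x, expected, eligibility, positive_weight,
                       total_sq_grad_positive, learning_rate):
    # Copy first so the caller's arrays stay untouched (the question's semantics).
    positive_weight = positive_weight.copy()
    total_sq_grad_positive = total_sq_grad_positive.copy()
    for i in range(eligibility.shape[0]):
        if x:
            g = eligibility[i] * (1 - expected)
        else:
            g = -(eligibility[i] * expected)
        total_sq_grad_positive[i] += g * g
        denom = np.sqrt(total_sq_grad_positive[i])
        if denom == 0:
            denom = 1.0
        positive_weight[i] += learning_rate * g / denom
    # Returned as a tuple; the original code returned a list.
    return positive_weight, total_sq_grad_positive


eligibility = np.array([0.1, 0.1, 0.1])
w = np.array([0.2, 0.2, 0.2])
t = np.array([0.1, 0.1, 0.1])

new_w, new_t = update_weight_copy(True, 0.5, eligibility, w, t, 1)
print(w, t)          # the inputs still hold their original values
print(new_w, new_t)  # the updated copies
```

The two extra allocations per call are exactly the temporaries the in-place version avoids, which is why this variant is slower.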


I am not sure whether it can be optimized to be "instantaneous"; I am a little surprised that Java manages it, since this looks like a fairly complicated function to me, with time-consuming operations like sqrt.

Also, did you apply nb.jit to the function(s) that call this function? Like this:

@nb.jit
def test(n):
    for i in range(n):
        update_weight_from_post_post_jit(x, expected,eligibility,positive_weight,total_sq_grad_positive,learning_rate)

On my computer, this cuts the running time in half, which makes sense because each call from the Python interpreter into a jitted function carries a relatively high overhead.
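A rough timing sketch of that comparison, calling the in-place kernel from a Python loop versus from a jitted loop. The function names python_loop, compiled_loop and fresh_state, and the value of n, are hypothetical; the guarded import falls back to plain Python without Numba, in which case both timings will be similar. Absolute numbers depend on the machine.

```python
import timeit

import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: without Numba installed, fall back to plain Python (slow but correct).
    def njit(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]
        return lambda f: f


@njit(fastmath=True)
def update_weight_from_post_post_jit(x, expected, eligibility, positive_weight,
                                     total_sq_grad_positive, learning_rate):
    # In-place kernel from the answer (condensed).
    for i in range(eligibility.shape[0]):
        if x:
            g = eligibility[i] * (1 - expected)
        else:
            g = -(eligibility[i] * expected)
        total_sq_grad_positive[i] += g * g
        denom = np.sqrt(total_sq_grad_positive[i])
        if denom == 0:
            denom = 1.0
        positive_weight[i] += learning_rate * g / denom


def python_loop(n, *args):
    # Every iteration goes from the interpreter into compiled code.
    for _ in range(n):
        update_weight_from_post_post_jit(*args)


@njit
def compiled_loop(n, x, expected, eligibility, positive_weight,
                  total_sq_grad_positive, learning_rate):
    # The whole loop is compiled, so the kernel is called without interpreter overhead.
    for _ in range(n):
        update_weight_from_post_post_jit(x, expected, eligibility, positive_weight,
                                         total_sq_grad_positive, learning_rate)


def fresh_state():
    # Hypothetical helper: new (eligibility, weights, squared-gradient sums) arrays.
    return (np.array([0.1, 0.1, 0.1]), np.array([0.2, 0.2, 0.2]),
            np.array([0.1, 0.1, 0.1]))


n = 20000
results = {}
for name, loop in (("python loop", python_loop), ("compiled loop", compiled_loop)):
    e, w, t = fresh_state()
    loop(1, True, 0.5, e, w.copy(), t.copy(), 1)  # warm-up call triggers compilation
    start = timeit.default_timer()
    loop(n, True, 0.5, e, w, t, 1)
    elapsed = timeit.default_timer() - start
    results[name] = (w, t)
    print("%s: %.4f s" % (name, elapsed))
```

Both loops perform the same arithmetic, so the final arrays should agree; only the per-call overhead differs.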
