Speeding up summation for loop in python

I have the following bottleneck and am wondering if anyone can suggest ways to speed it up.

I have three lists x, y, z of length N, and I apply the following summation:

def abs_val_diff(x1, x2, x3, y1, y2, y3):
    """ Find the absolute value of the difference between x and y """
    return py.sqrt((x1 - y1) ** 2.0 + (x2 - y2) ** 2.0 + (x3 - y3) ** 2.0)

R = 0.1
sumV = 0.0
for i in xrange(N):
    for j in xrange(i + 1, N):
        if R > abs_val_diff(x[i], y[i], z[i],
                            x[j], y[j], z[j]):
            sumV += 1.0

I have tried using numpy arrays, but either I am doing something wrong or there is a reduction in speed of about a factor of 2.

Any ideas would be highly appreciated.

I believe you can use numpy more efficiently by doing something like the following. First, make a small modification to your function so that it uses numpy.sqrt:

import numpy as np

def abs_val_diff(x1, x2, x3, y1, y2, y3):
    """ Find the absolute value of the difference between x and y """
    return np.sqrt((x1 - y1) ** 2.0 + (x2 - y2) ** 2.0 + (x3 - y3) ** 2.0)

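Then call it with array slices instead of single elements, comparing point i against all later points at once so that every pair with i < j is still covered. Because you're adding 1 for each match, you can simply add the length of the array resulting from a query against the result:

x, y, z = np.asarray(x), np.asarray(y), np.asarray(z)
sumV = 0.0
for i in xrange(N - 1):
    # distances from point i to every later point j > i
    res = abs_val_diff(x[i], y[i], z[i], x[i + 1:], y[i + 1:], z[i + 1:])
    sumV += len(res[R > res])

This lets numpy handle the inner iteration. Hopefully that works for you.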
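If N is small enough that an N-by-N array fits in memory, the Python-level loop can probably be dropped altogether with broadcasting. What follows is only a sketch under that assumption: count_close_pairs_numpy is a made-up helper name, it is written for Python 3, and it compares squared distances so no square root is needed.

```python
import numpy as np

def count_close_pairs_numpy(x, y, z, R):
    """Count pairs i < j whose Euclidean distance is below R.

    Hypothetical helper, not part of the original question.
    """
    pts = np.column_stack((x, y, z)).astype(float)  # shape (N, 3)
    diff = pts[:, None, :] - pts[None, :, :]        # shape (N, N, 3)
    dist_sq = (diff ** 2).sum(axis=-1)              # squared distances, shape (N, N)
    i_idx, j_idx = np.triu_indices(len(pts), k=1)   # index pairs with i < j only
    return int(np.count_nonzero(dist_sq[i_idx, j_idx] < R * R))
```

Memory grows as N squared here, so for large N it is likely better to process the points in chunks or keep a loop over i.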

Is there any reason you actually need to take the square root in your function? If all you do with the result is compare it against a limit, why not just square both sides of the comparison?

def abs_val_diff_squared(x1, x2, x3, y1, y2, y3):
    """ Find the square of the absolute value of the difference between x and y """
    return (x1 - y1) ** 2.0 + (x2 - y2) ** 2.0 + (x3 - y3) ** 2.0

R = 0.1
R_squared = R * R
sumV = 0.0
for i in xrange(N):
    for j in xrange(i + 1, N):
        if R_squared > abs_val_diff_squared(x[i], y[i], z[i],
                                            x[j], y[j], z[j]):
            sumV += 1.0

I also feel there ought to be much bigger savings from sorting the data into something like an octree, so that you only have to look at nearby points rather than comparing everything against everything, but that's outside my knowledge.
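To make the "only look at nearby points" idea concrete, here is a rough sketch that uses a uniform grid of cubes with side R instead of a real octree (a simpler structure, swapped in for illustration). Two points closer than R must sit in the same cube or in adjacent cubes, so only those need comparing. The name count_close_pairs_grid is made up, the code is written for Python 3, and it assumes R > 0.

```python
from collections import defaultdict
from itertools import product
import math

def count_close_pairs_grid(x, y, z, R):
    """Count pairs i < j closer than R by binning points into cubes of side R.

    Hypothetical helper; a uniform grid stands in for the octree idea.
    """
    # Map each cube (integer cell coordinates) to the indices of its points.
    cells = defaultdict(list)
    for i in range(len(x)):
        key = (math.floor(x[i] / R), math.floor(y[i] / R), math.floor(z[i] / R))
        cells[key].append(i)

    R_squared = R * R
    count = 0
    for (cx, cy, cz), members in cells.items():
        # Visit the cube itself and its 26 neighbours.
        for dx, dy, dz in product((-1, 0, 1), repeat=3):
            neighbours = cells.get((cx + dx, cy + dy, cz + dz))
            if not neighbours:
                continue
            for i in members:
                for j in neighbours:
                    if i < j:  # count each unordered pair exactly once
                        d2 = ((x[i] - x[j]) ** 2 + (y[i] - y[j]) ** 2
                              + (z[i] - z[j]) ** 2)
                        if d2 < R_squared:
                            count += 1
    return count
```

When the points are spread out relative to R, each point is compared with only a handful of others, so the work should grow roughly linearly with N rather than quadratically. If most points fall within R of each other, this gains nothing over the double loop.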

It turns out long, ugly list comprehensions and generator expressions are often faster than explicit loops in python because they can be compiled to more efficient bytecode. I'm not sure if it'll help for you, but try something like this (note that the loop over i has to come first, so that i is defined when the range for j is built):

sumV = sum(1.0 for i in xrange(N) for j in xrange(i + 1, N) if R > abs_val_diff(x[i], y[i], z[i], x[j], y[j], z[j]))

Yes, it looks absolutely atrocious, but there you go. More info can be found here and here.
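A somewhat more readable take on the same idea is to let itertools.combinations produce each unordered pair exactly once, which removes the manual i/j bookkeeping. This is only a sketch: count_close_pairs_combinations is a made-up name, the code is written for Python 3, and it compares squared distances. Whether it beats the explicit loops on your data is something you would have to time yourself.

```python
from itertools import combinations

def count_close_pairs_combinations(x, y, z, R):
    """Count pairs i < j whose Euclidean distance is below R.

    Hypothetical helper, not part of the original question.
    """
    points = list(zip(x, y, z))  # one (x, y, z) tuple per point
    R_squared = R * R
    return sum(
        1
        for (x1, y1, z1), (x2, y2, z2) in combinations(points, 2)
        if (x1 - x2) ** 2 + (y1 - y2) ** 2 + (z1 - z2) ** 2 < R_squared
    )
```

It is still an all-pairs comparison, so the running time keeps growing with the square of N.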
