
Optimizing array operations in Python with Numpy

I'm baffled. I just ported my code from Java to Python. The good news is that the Python alternative for the library I'm using is much quicker. The bad news is that my custom processing code is much slower in the Python version I wrote :( I even removed some parts I deemed unnecessary, and it is still much slower. The Java version took about half a second; the Python version takes 5-6 seconds.

import time
import imageio
import numpy as np

rimg1 = imageio.imread('test1.png').astype(np.uint8)
rimg2 = imageio.imread('test2.png').astype(np.uint8)

sum_time = 0
for offset in range(-left, right):
    rdest = np.zeros((h, w, 3), dtype=np.uint8)

    if offset == 0:
        continue

    mult = np.uint8(1.0 / (offset * multiplier / frames))
    for y in range(h):
        for x in range(0, w - backup, 1):
            slice_time = time.time()
            src = rimg2[y,x] // mult + 1
            sum_time += time.time() - slice_time

            pix = rimg1[y,x + backup]

w and h are both about 384, and src usually ranges from 0 to 30. The offset runs from -5 to 5 (left to right).

How come sum_time takes about a third of my total time?

Edit

With the help of josephjscheidt I made some changes.

mult = np.uint8(1.0 / (offset * multiplier / frames))
multArray = np.floor_divide(rimg2, mult) + 1
for y in range(h):
    pixy = rimg1[y]
    multy = multArray[y]
    for x in range(0, w - backup, 1):
        src = multy[y]
        slice_time = time.time()
        pix = pixy[x + backup]
        sum_time += time.time() - slice_time
        ox = x
        for o in range(src):
            if ox < 0:
                break

            rdest[y,ox] = pix
            ox-=1

Using the numpy iterator for the srcArray cuts total time almost in half! The numpy operation itself seems to take negligible time.

Now most of the time is spent in the rimg1 lookup

pix = rimg1[x + backup]

and the inner for loop (each taking about 50% of the time). Is it possible to handle this with numpy operations as well?

Edit

I figured that replacing the inner loop with a slice assignment would help, but somehow the following actually takes a little longer:

    for x in range(0, w - backup, 1):
        slice_time = time.time()
        lastox = max(x-multy[y], 0)
        rdest[y,lastox:x] = pixy[x + backup]
        sum_time += time.time() - slice_time

Edit

            slice_time = time.time()
            depth = multy[y]
            pix = pixy[x + backup]

            ox = x

            #for o in range(depth):
            #    if ox < 0:
            #        break;
            #
            #    rdesty[ox] = pix
            #    ox-=1

            # if I uncomment the above lines, and comment out the following two
            # it takes twice as long!
            lastox = max(x-multy[y], 0)
            rdesty[lastox:x] = pixy[x + backup]

            sum_time += time.time() - slice_time

The Python interpreter is strange...

The time taken is now 2.5 seconds for sum_time. In comparison, Java does it in 60 ms.

Python-level for loops over numpy arrays are notoriously slow, and you have a three-layer for loop here. The underlying idea with numpy arrays is to perform operations on the entire array at once, rather than iterating over the elements one by one.
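As a small illustration of the difference, here is a sketch comparing the per-pixel division from the question with the same computation done on the whole array at once. The array contents, the 384x384 size, the value of mult, and the function names are made-up stand-ins, since the real inputs aren't shown.

```python
import numpy as np

# Stand-in for the image in the question (assumed: 2-D, uint8, 384x384).
rng = np.random.default_rng(0)
rimg2 = rng.integers(0, 256, size=(384, 384), dtype=np.uint8)
mult = np.uint8(8)  # hypothetical value; the real one depends on offset


def src_loop(img, mult):
    """Per-pixel version: one Python-level operation per element."""
    h, w = img.shape
    out = np.empty((h, w), dtype=np.uint8)
    for y in range(h):
        for x in range(w):
            out[y, x] = img[y, x] // mult + 1
    return out


def src_vectorized(img, mult):
    """Whole-array version: the loop runs inside numpy's compiled code."""
    return img // mult + 1


# Both versions produce the same values.
assert np.array_equal(src_loop(rimg2, mult), src_vectorized(rimg2, mult))
```

Wrapping each version in time.perf_counter() calls on your own data should show how large the gap is on your machine.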

Although I can't fully interpret your code, because most of the variables are undefined in the snippet you provided, I'm fairly confident you can refactor it and vectorize the operations to remove the loops. For instance, if you redefine offset as a one-dimensional array, you can calculate all values of mult at once without a for loop: mult becomes a one-dimensional array holding the correct values. We can avoid dividing by zero with the out argument (which supplies the values that are left in place wherever nothing is computed; I use a float array of zeros here, because writing the float results into an integer offset array would raise a casting error) and the where argument (which performs the calculation only where offset doesn't equal zero):

mult = np.uint8(np.divide(1.0, offset * multiplier / frames,
                          out = np.zeros(offset.shape), where = (offset != 0)))

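Then, to apply the mult array to rimg2 row by row, you can use a broadcasting trick (here, I'm assuming that rimg2 is two-dimensional, that mult holds one value per row, and that you want to add one to each element in rimg2). The mask needs the same [:,None] reshaping so that it lines up with rows rather than columns, and out is a copy so that rimg2 itself is not overwritten:

src = np.floor_divide(rimg2, mult[:,None], out = rimg2.copy(), where = (mult[:,None] != 0)) + 1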
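If what you want is instead one src array per offset (which is how the outer loop in the question reads), the same broadcasting idea can be sketched as below. This is only a sketch under stated assumptions: the values of left, right, multiplier and frames are invented, rimg2 is assumed to be a 2-D uint8 array, and mult is computed from the absolute offset purely so that the conversion to uint8 stays well defined for negative offsets.

```python
import numpy as np

# Hypothetical parameters; the question does not show the real values.
left, right, multiplier, frames = 5, 5, 1.0, 40.0

rng = np.random.default_rng(0)
rimg2 = rng.integers(0, 256, size=(384, 384), dtype=np.uint8)  # assumed 2-D

offset = np.arange(-left, right)  # shape (n,), includes 0
# Assumption: abs() is used so no negative value is ever cast to uint8.
scale = np.abs(offset) * multiplier / frames

# 1.0 / scale, leaving 0 where scale == 0 instead of dividing by zero.
mult = np.divide(1.0, scale, out=np.zeros(scale.shape), where=(scale != 0))
mult = mult.astype(np.uint8)

# Reshape mult to (n, 1, 1) so it broadcasts against the (h, w) image,
# giving one (h, w) result per offset, i.e. an array of shape (n, h, w).
m = mult[:, None, None]
src = np.floor_divide(
    rimg2, m,
    out=np.zeros(m.shape[:1] + rimg2.shape, dtype=np.uint8),
    where=(m != 0),
) + 1
```

The slice of src that belongs to offset 0 is just a placeholder (all ones), matching the question's loop, which skips that offset entirely.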

I found this article extremely helpful when learning how to effectively work with numpy arrays:

https://realpython.com/numpy-array-programming/

Since you are working with images, you may want to especially pay attention to the section on image feature extraction and stride_tricks. Anyway, I hope this helps you get started.
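On the follow-up in the question's edits (whether the inner fill loop can also be handled with numpy operations), one possible approach is sketched below. It is not a drop-in replacement: it assumes depth is a 2-D integer array holding the per-pixel count that the question calls src (read as multArray[y, x]), that rimg1 has shape (h, w) or (h, w, 3), and the names fill_loop and fill_vectorized are made up for illustration. The idea is to loop over the fill distance, which the question says is usually at most about 30, instead of over every pixel.

```python
import numpy as np


def fill_loop(rimg1, depth, backup):
    """Reference: the nested loops from the question, written out in full."""
    h, w = depth.shape
    rdest = np.zeros_like(rimg1)
    for y in range(h):
        for x in range(w - backup):
            pix = rimg1[y, x + backup]
            ox = x
            for _ in range(int(depth[y, x])):
                if ox < 0:
                    break
                rdest[y, ox] = pix
                ox -= 1
    return rdest


def fill_vectorized(rimg1, depth, backup):
    """Same result, looping over the fill distance d instead of over pixels."""
    h, w = depth.shape
    n = w - backup                      # number of source columns x
    rdest = np.zeros_like(rimg1)
    if n <= 0:
        return rdest
    d_src = depth[:, :n]                # fill length for each source column
    pix = rimg1[:, backup:backup + n]   # pixel written by each source column
    # Column x writes to x, x-1, ..., x-depth+1. Later x overwrites earlier x,
    # so for a fixed target the largest distance d = x - target wins:
    # process d in ascending order and let later passes overwrite.
    for d in range(min(int(d_src.max()), n)):
        mask = d_src[:, d:] > d         # sources x >= d that reach d to the left
        target = rdest[:, :n - d]       # view of the target columns x - d
        target[mask] = pix[:, d:][mask]
    return rdest
```

With fill lengths of up to about 30, this makes roughly 30 whole-array passes per offset rather than one Python iteration per pixel. It is worth checking fill_vectorized against fill_loop on a small image of your own before relying on it.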
