繁体   English   中英

在 numpy 中随机索引多维 arrays

[英]Randomly Indexing Multi-dimensional arrays in numpy

如何更改此代码避免使用 python for-loop,但 numpy function

h, w = 2, 2
im = np.random.randint(255, size=(h, w, 3))
index = np.random.randint(3, size=(h, w))
number = np.random.randint(255, size=(h, w))
for i in range(h):
    for j in range(w):
        im[i, j, index[i, j]] += number[i, j]

下面的代码与您的原始代码相同,但避免使用for循环。 但是,您为什么要这样做尚不清楚,因为我认为这个解决方案肯定比原来的解决方案更糟糕。

from timeit import timeit
import numpy as np

h, w = 20, 20
im = np.random.randint(255, size=(h, w, 3))


def increase_random(x):
    result = np.copy(x)
    result[np.random.randint(3)] += np.random.randint(255)
    return result


def loops():
    index = np.random.randint(3, size=(h, w))
    number = np.random.randint(255, size=(h, w))
    for i in range(h):
        for j in range(w):
            im[i, j, index[i, j]] += number[i, j]


def vectorized():
    irv(im)


irv = np.vectorize(increase_random, signature='(n)->(n)')

print(timeit(vectorized, number=10))
print(timeit(loops, number=10))

我添加了一些时间测量来表明,在这种情况下,矢量化对提高性能没有任何作用。 在我的机器上, loops代码快了大约 25 倍。

但是,如果您正在执行的操作更简单或更复杂但更易于优化,则它可能会从矢量化中受益。 碰巧您的示例不太可能从中受益,而循环相当小且有效。

对于像您这样小的示例,您将很难加快循环代码的速度。 numpy支付从某个问题大小向上产生的开销:

在此处输入图像描述

plot 显示原始循环代码 ( OP ) 与矢量化代码 ( pp ) 在总像素数wxh上的执行时间。

它是使用以下方法生成的:

from simple_benchmark import BenchmarkBuilder, MultiArgument
import numpy as np
from scipy.misc import face

B = BenchmarkBuilder()

@B.add_function()
def OP(im,index,number):
    im = im.copy()
    h,w,_ = im.shape
    for i in range(h):
        for j in range(w):
            im[i, j, index[i, j]] += number[i, j]
    return im

@B.add_function()
def pp(im,index,number):
    im = im.copy()
    h,w,_ = im.shape
    h,w = np.ogrid[:h,:w]
    im[h,w,index] += number
    return im

@B.add_arguments('#pixels')
def argument_provider():
    im = face()
    h,w,_ = im.shape
    mh,mw = h//2,w//2
    for exp in range(-8,1):
        fr = 2.**exp
        dh,dw = int(fr*mh),int(fr*mw)
        index = np.random.randint(3, size=(2*dh, 2*dw))
        number = np.random.randint(255, size=(2*dh, 2*dw),dtype=im.dtype)
        yield 4*dh*dw,MultiArgument([im[mh-dh:mh+dh,mw-dw:mw+dw],index,number])

r = B.run()
r.plot()

import pylab
pylab.savefig('randomchannel.png')

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM