[英]Time performance of np.random.permutation, np.random.choice
与纯python图形理论库中的类似MATLAB代码相比,我遇到了一个时间性能很差的函数,因此我尝试介绍此函数中的某些操作。
我跟踪到以下结果
In [27]: timeit.timeit( 'permutation(138)[:4]', setup='from numpy.random import permutation', number=1000000)
Out[27]: 27.659916877746582
将此与MATLAB中的性能进行比较
>> tic; for i=1:1000000; randperm(138,4); end; toc
Elapsed time is 4.593305 seconds.
通过将其更改为np.random.choice
而不是我最初编写的np.random.permutation
,我能够显着提高性能。
In [42]: timeit.timeit( 'choice(138, 4)', setup='from numpy.random import choice', number=1000000)
Out[42]: 18.9618501663208
但是它仍然没有接近Matlab的性能。
是否有另一种方式可以在纯python中获得这种行为,而时间性能接近MATLAB时间性能?
基于this solution
是一个显示如何模拟np.random.choice(..., replace=False)
的基于一招行为argsort
/ argpartition
,您可以重新创建MATLAB的randperm(138,4)
即与NumPy的np.random.choice(138,4, replace=False)
与np.argpartition
为:
np.random.rand(138).argpartition(range(4))[:4]
或者像这样使用np.argsort
np.random.rand(138).argsort()[:4]
我们将这两个版本的时间与MATLAB版本进行性能比较。
在MATLAB上-
>> tic; for i=1:1000000; randperm(138,4); end; toc
Elapsed time is 1.058177 seconds.
在带有np.argpartition
NumPy上-
In [361]: timeit.timeit( 'np.random.rand(138).argpartition(range(4))[:4]', setup='import numpy as np', number=1000000)
Out[361]: 9.063489798831142
在带有np.argsort
NumPy上-
In [362]: timeit.timeit( 'np.random.rand(138).argsort()[:4]', setup='import numpy as np', number=1000000)
Out[362]: 5.74625801707225
最初建议使用NumPy-
In [363]: timeit.timeit( 'choice(138, 4)', setup='from numpy.random import choice', number=1000000)
Out[363]: 6.793723535243771
似乎可以使用np.argsort
来提高性能。
这需要多长时间? 我估计需要1-2秒。
def four():
k = np.random.randint(138**4)
a = k % 138
b = k // 138 % 138
c = k // 138**2 % 138
d = k // 138**3 % 138
return (a, b, c, d) if a != b and a != c and a != d and b != c and b != d and c != d else four()
更新1:最初,我使用random.randrange
,但np.random.randint
使整个过程快了两倍。
更新2:由于NumPy的随机函数似乎要快得多,所以我尝试了这一点,这是另一个快〜1.33的因素:
>>> def four():
a = randint(138)
b = randint(138)
c = randint(138)
d = randint(138)
return (a, b, c, d) if a != b and a != c and a != d and b != c and b != d and c != d else four()
>>> import timeit
>>> from numpy.random import randint
>>> timeit.timeit(lambda: four(), number=1000000)
2.3742770821572776
这比原始速度快22倍:
>>> timeit.timeit('permutation(138)[:4]', setup='from numpy.random import permutation', number=1000000)
51.80568455893672
(字符串与lambda
区别不大)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.