[英]Time performance of np.random.permutation, np.random.choice
與純python圖形理論庫中的類似MATLAB代碼相比,我遇到了一個時間性能很差的函數,因此我嘗試介紹此函數中的某些操作。
我跟蹤到以下結果
In [27]: timeit.timeit( 'permutation(138)[:4]', setup='from numpy.random import permutation', number=1000000)
Out[27]: 27.659916877746582
將此與MATLAB中的性能進行比較
>> tic; for i=1:1000000; randperm(138,4); end; toc
Elapsed time is 4.593305 seconds.
通過將其更改為np.random.choice
而不是我最初編寫的np.random.permutation
,我能夠顯着提高性能。
In [42]: timeit.timeit( 'choice(138, 4)', setup='from numpy.random import choice', number=1000000)
Out[42]: 18.9618501663208
但是它仍然沒有接近Matlab的性能。
是否有另一種方式可以在純python中獲得這種行為,而時間性能接近MATLAB時間性能?
基於this solution
是一個顯示如何模擬np.random.choice(..., replace=False)
的基於一招行為argsort
/ argpartition
,您可以重新創建MATLAB的randperm(138,4)
即與NumPy的np.random.choice(138,4, replace=False)
與np.argpartition
為:
np.random.rand(138).argpartition(range(4))[:4]
或者像這樣使用np.argsort
np.random.rand(138).argsort()[:4]
我們將這兩個版本的時間與MATLAB版本進行性能比較。
在MATLAB上-
>> tic; for i=1:1000000; randperm(138,4); end; toc
Elapsed time is 1.058177 seconds.
在帶有np.argpartition
NumPy上-
In [361]: timeit.timeit( 'np.random.rand(138).argpartition(range(4))[:4]', setup='import numpy as np', number=1000000)
Out[361]: 9.063489798831142
在帶有np.argsort
NumPy上-
In [362]: timeit.timeit( 'np.random.rand(138).argsort()[:4]', setup='import numpy as np', number=1000000)
Out[362]: 5.74625801707225
最初建議使用NumPy-
In [363]: timeit.timeit( 'choice(138, 4)', setup='from numpy.random import choice', number=1000000)
Out[363]: 6.793723535243771
似乎可以使用np.argsort
來提高性能。
這需要多長時間? 我估計需要1-2秒。
def four():
k = np.random.randint(138**4)
a = k % 138
b = k // 138 % 138
c = k // 138**2 % 138
d = k // 138**3 % 138
return (a, b, c, d) if a != b and a != c and a != d and b != c and b != d and c != d else four()
更新1:最初,我使用random.randrange
,但np.random.randint
使整個過程快了兩倍。
更新2:由於NumPy的隨機函數似乎要快得多,所以我嘗試了這一點,這是另一個快〜1.33的因素:
>>> def four():
a = randint(138)
b = randint(138)
c = randint(138)
d = randint(138)
return (a, b, c, d) if a != b and a != c and a != d and b != c and b != d and c != d else four()
>>> import timeit
>>> from numpy.random import randint
>>> timeit.timeit(lambda: four(), number=1000000)
2.3742770821572776
這比原始速度快22倍:
>>> timeit.timeit('permutation(138)[:4]', setup='from numpy.random import permutation', number=1000000)
51.80568455893672
(字符串與lambda
區別不大)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.