[英]Fastest way to sort in Python (no cython)
我有一個問題,我要用自定義函數對一個非常大的數組(形狀 - 7900000X4X4)進行排序。 我使用了sorted
但sorted
花了1個多小時。 我的代碼是這樣的。
def compare(x,y):
print('DD '+str(x[0]))
if(np.array_equal(x[1],y[1])==True):
return -1
a = x[1].flatten()
b = y[1].flatten()
idx = np.where( (a>b) != (a<b) )[0][0]
if a[idx]<0 and b[idx]>=0:
return 0
elif b[idx]<0 and a[idx]>=0:
return 1
elif a[idx]<0 and b[idx]<0:
if a[idx]>b[idx]:
return 0
elif a[idx]<b[idx]:
return 1
elif a[idx]<b[idx]:
return 1
else:
return 0
def cmp_to_key(mycmp):
class K:
def __init__(self, obj, *args):
self.obj = obj
def __lt__(self, other):
return mycmp(self.obj, other.obj)
return K
tblocks = sorted(tblocks.items(),key=cmp_to_key(compare))
這有效,但我希望它能在幾秒鍾內完成。 我認為在python中沒有任何直接實現可以給我我需要的性能,所以我嘗試了cython。 我的Cython代碼就是這個,非常簡單。
cdef int[:,:] arrr
cdef int size
cdef bool compare(int a,int b):
global arrr,size
cdef int[:] x = arrr[a]
cdef int[:] y = arrr[b]
cdef int i,j
i = 0
j = 0
while(i<size):
if((j==size-1)or(y[j]<x[i])):
return 0
elif(x[i]<y[j]):
return 1
i+=1
j+=1
return (j!=size-1)
def sorted(np.ndarray boxes,int total_blocks,int s):
global arrr,size
cdef int i
cdef vector[int] index = xrange(total_blocks)
arrr = boxes
size = s
sort(index.begin(),index.end(),compare)
return index
cython中的這段代碼用了33秒! Cython是解決方案,但我正在尋找一些可以直接在python上運行的替代解決方案。 例如numba。 我嘗試了Numba,但我沒有得到令人滿意的結果。 請幫忙!
沒有一個有效的例子,很難給出答案。 我假設你的Cython代碼中的arrr是一個2D數組,我假設這個大小是size=arrr.shape[0]
Numba實施
import numpy as np
import numba as nb
from numba.targets import quicksort
def custom_sorting(compare_fkt):
index_arange=np.arange(size)
quicksort_func=quicksort.make_jit_quicksort(lt=compare_fkt,is_argsort=False)
jit_sort_func=nb.njit(quicksort_func.run_quicksort)
index=jit_sort_func(index_arange)
return index
def compare(a,b):
x = arrr[a]
y = arrr[b]
i = 0
j = 0
while(i<size):
if((j==size-1)or(y[j]<x[i])):
return False
elif(x[i]<y[j]):
return True
i+=1
j+=1
return (j!=size-1)
arrr=np.random.randint(-9,10,(7900000,8))
size=arrr.shape[0]
index=custom_sorting(compare)
這為生成的測試數據提供了3.85秒 。 但排序算法的速度在很大程度上取決於數據....
簡單的例子
import numpy as np
import numba as nb
from numba.targets import quicksort
#simple reverse sort
def compare(a,b):
return a > b
#create some test data
arrr=np.array(np.random.rand(7900000)*10000,dtype=np.int32)
#we can pass the comparison function
quicksort_func=quicksort.make_jit_quicksort(lt=compare,is_argsort=True)
#compile the sorting function
jit_sort_func=nb.njit(quicksort_func.run_quicksort)
#get the result
ind_sorted=jit_sort_func(arrr)
此實現比np.argsort慢約35%,但這在編譯代碼中使用np.argsort時也很常見。
如果我正確地理解了你的代碼,那么你所考慮的順序就是標准順序,只是從0
開始在+/-infinity
,最大值在-0
。 最重要的是,我們有簡單的從左到右的詞典順序。
現在,如果您的數組dtype是整數,請觀察以下內容:由於負數的補碼表示,視圖轉換為unsigned int使您的訂單成為標准訂單。 最重要的是,如果我們使用大端編碼,可以通過視圖轉換為void
dtype來實現有效的詞典排序。
下面的代碼顯示使用10000x4x4
示例,此方法提供與Python代碼相同的結果。
它還在7,900,000x4x4
示例(使用數組,而不是dict)上對其進行基准測試。 在我適度的筆記本電腦上,此方法需要8
秒鍾。
import numpy as np
def compare(x, y):
# print('DD '+str(x[0]))
if(np.array_equal(x[1],y[1])==True):
return -1
a = x[1].flatten()
b = y[1].flatten()
idx = np.where( (a>b) != (a<b) )[0][0]
if a[idx]<0 and b[idx]>=0:
return 0
elif b[idx]<0 and a[idx]>=0:
return 1
elif a[idx]<0 and b[idx]<0:
if a[idx]>b[idx]:
return 0
elif a[idx]<b[idx]:
return 1
elif a[idx]<b[idx]:
return 1
else:
return 0
def cmp_to_key(mycmp):
class K:
def __init__(self, obj, *args):
self.obj = obj
def __lt__(self, other):
return mycmp(self.obj, other.obj)
return K
def custom_sort(a):
assert a.dtype==np.int64
b = a.astype('>i8', copy=False)
return b.view(f'V{a.dtype.itemsize * a.shape[1]}').ravel().argsort()
tblocks = np.random.randint(-9,10, (10000, 4, 4))
tblocks = dict(enumerate(tblocks))
tblocks_s = sorted(tblocks.items(),key=cmp_to_key(compare))
tblocksa = np.array(list(tblocks.values()))
tblocksa = tblocksa.reshape(tblocksa.shape[0], -1)
order = custom_sort(tblocksa)
tblocks_s2 = list(tblocks.items())
tblocks_s2 = [tblocks_s2[o] for o in order]
print(tblocks_s == tblocks_s2)
from timeit import timeit
data = np.random.randint(-9_999, 10_000, (7_900_000, 4, 4))
print(timeit(lambda: data[custom_sort(data.reshape(data.shape[0], -1))],
number=5) / 5)
樣本輸出:
True
7.8328493310138585
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.