[英]Improve performance in lists
我有一個問題,我試圖采取隨機排序列表,我想知道有多少索引比當前元素更大的元素值比當前元素小。
例如:
[1,2,5,3,7,6,8,4]
應該返回:
[0,0,2,0,2,1,1,0]
這是我目前正在使用的代碼。
bribe_array = [0] * len(q)
for i in range(0, len(bribe_array)-1):
bribe_array[i] = sum(j<q[i] for j in q[(i+1):])
這確實產生了所需的陣列,但運行緩慢。 實現這一目標的更多pythonic方法是什么?
我們可以解決問題中的代碼,但它仍然是一個O(n^2)
算法。 要真正提高性能,不是要使實現或多或少pythonic,而是使用輔助數據結構的不同方法。
下面是O(n log n)
解決方案的概述:實現自平衡BST ( AVL或紅黑是很好的選項),並在每個節點中另外存儲一個屬性,該屬性具有以其為根的子樹的大小。 現在從右到左遍歷列表,並將其所有元素作為新節點插入樹中。 我們還需要一個與輸入列表大小相同的額外輸出列表來跟蹤答案。
對於我們在樹中插入的每個節點,我們將其鍵與根進行比較。 如果它大於根中的值,則意味着它大於左子樹中的所有節點,因此我們需要將左子樹的大小添加到我們嘗試的元素位置的答案列表中插入。
我們繼續遞歸地執行此操作並更新我們訪問的每個節點中的size屬性,直到找到插入新節點的正確位置,然后繼續執行輸入列表中的下一個元素。 最后,輸出列表將包含答案。
另一個比實現平衡BST簡單得多的選擇是調整合並排序以計算反轉並在過程中累積它們。 顯然,任何單個交換都是反轉,因此較低索引的元素獲得一個計數。 然后在合並遍歷期間,只需跟蹤右側組中有多少元素向左移動,並為添加到右側組的元素添加該計數。
這是一個非常粗略的插圖:)
[1,2,5,3,7,6,8,4]
sort 1,2 | 5,3
3,5 -> 5: 1
merge
1,2,3,5
sort 7,6 | 8,4
6,7 -> 7: 1
4,8 -> 8: 1
merge
4 -> 6: 1, 7: 2
4,6,7,8
merge 1,2,3,5 | 4,6,7,8
1,2,3,4 -> 1 moved
5 -> +1 -> 5: 2
6,7,8
有幾種方法可以在不影響整體計算復雜性的情況下加速代碼。
這是因為有幾種方法可以編寫這種算法。
讓我們從你的代碼開始:
def bribe_orig(q):
bribe_array = [0] * len(q)
for i in range(0, len(bribe_array)-1):
bribe_array[i] = sum(j<q[i] for j in q[(i+1):])
return bribe_array
這有點混合風格:首先,你生成一個零列表(這不是真正需要的,因為你可以按需添加項目;其次,外部列表使用一個range()
,這是你想要的次優多次訪問特定項目,因此本地名稱會更快;第三,你在sum()
寫一個生成器也是次優的,因為它將總結布爾值,因此一直執行隱式轉換。
更清潔的方法是:
def bribe(items):
result = []
for i, item in enumerate(items):
partial_sum = 0
for x in items[i + 1:]:
if x < item:
partial_sum += 1
result.append(partial_sum)
return result
這有點簡單,因為它明確地做了很多事情,並且只在必要時執行求和(因此當你添加0時跳過),它可能會更快。
另一種以更緊湊的方式編寫代碼的方法是:
def bribe_compr(items):
return [sum(x < item for x in items[i + 1:]) for i, item in enumerate(items)]
這涉及使用生成器和列表推導,但外部循環使用enumerate()
按照典型的Python樣式編寫。
但Python在原始循環中聲名狼借,因此在可能的情況下,矢量化可能會有所幫助。 這樣做的一種方法(僅適用於內循環)是numpy
:
import numpy as np
def bribe_np(items):
items = np.array(items)
return [np.sum(items[i + 1:] < item) for i, item in enumerate(items)]
最后,可以使用JIT編譯器來加速使用Numba的普通Python循環:
import numba
bribe_jit = nb.jit(bribe)
至於任何JIT,它都有一些即時編譯的成本,最終會因足夠大的循環而被抵消。 不幸的是,Numba的JIT並不支持所有Python代碼,但是當它發生時,就像在這種情況下一樣,它可能非常有價值。
我們來看看一些數字。
考慮使用以下內容生成的輸入:
import numpy as np
np.random.seed(0)
n = 10
q = np.random.randint(1, n, n)
在小尺寸輸入( n = 10
)上:
%timeit bribe_orig(q)
# 228 µs ± 3.56 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit bribe(q)
# 20.3 µs ± 814 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit bribe_compr(q)
# 216 µs ± 5.32 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit bribe_np(q)
# 133 µs ± 9.16 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit bribe_jit(q)
# 1.11 µs ± 17.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
在中等大小的輸入上( n = 100
):
%timeit bribe_orig(q)
# 20.5 ms ± 398 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit bribe(q)
# 741 µs ± 11.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit bribe_compr(q)
# 18.9 ms ± 202 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit bribe_np(q)
# 1.22 ms ± 27.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit bribe_jit(q)
# 7.54 µs ± 165 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
在較大的輸入( n = 10000
)上:
%timeit bribe_orig(q)
# 1.99 s ± 19.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit bribe(q)
# 60.6 ms ± 280 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit bribe_compr(q)
# 1.8 s ± 11.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit bribe_np(q)
# 12.8 ms ± 32.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit bribe_jit(q)
# 182 µs ± 2.66 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
從這些結果中,我們觀察到我們從使用僅涉及Python循環的顯式構造替換sum()
獲得了最多。 理解的使用不會讓你超過約。 比您的代碼提高10%。 對於更大的輸入,NumPy的使用甚至可以比僅涉及Python循環的顯式構造更快。 但是,當您使用Numba的JITed版本的bribe()
時,您將獲得真正的優惠。
通過逐步構建數組中從上到下的排序列表,可以獲得更好的性能。 在排序列表中對數組中的每個元素使用二進制搜索算法,您將獲得插入元素的索引,這也恰好是已經處理的元素中較小的元素數。
收集這些插入點將為您提供預期的結果(以相反的順序)。
這是一個例子:
a = [1,2,5,3,7,6,8,4]
from bisect import bisect_left
s = []
r = []
for x in reversed(a):
p = bisect_left(s,x)
r.append(p)
s.insert(p,x)
r = r[::-1]
print(r) #[0,0,2,0,2,1,1]
對於此示例,進展如下:
step 1: x = 4, p=0 ==> r=[0] s=[4]
step 2: x = 8, p=1 ==> r=[0,1] s=[4,8]
step 3: x = 6, p=1 ==> r=[0,1,1] s=[4,6,8]
step 4: x = 7, p=2 ==> r=[0,1,1,2] s=[4,6,7,8]
step 5: x = 3, p=0 ==> r=[0,1,1,2,0] s=[3,4,6,7,8]
step 6: x = 5, p=2 ==> r=[0,1,1,2,0,2] s=[3,4,5,6,7,8]
step 7: x = 2, p=0 ==> r=[0,1,1,2,0,2,0] s=[2,3,4,5,6,7,8]
step 8: x = 1, p=0 ==> r=[0,1,1,2,0,2,0,0] s=[1,2,3,4,5,6,7,8]
Reverse r, r = r[::-1] r=[0,0,2,0,2,1,1,0]
您將執行N個循環(數組的大小),二進制搜索在log( i )中執行,其中i為1到N.因此,小於O(N * log(N))。 唯一需要注意的是s.insert(p,x)的性能,它將根據原始列表的順序引入一些可變性。 總體而言,當陣列已經排序時,性能簡檔應該在O(N)和O(N * log(N))之間,最壞情況是O(n ^ 2)。
如果您只需要使代碼更快更簡潔,則可以使用列表推導(但仍然是O(n ^ 2)時間):
r = [sum(v<p for v in a[i+1:]) for i,p in enumerate(a)]
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.