[英]Intersection of two lists in numba
我想知道在 numba 函數中計算兩個列表交集的最快方法。 只是為了澄清:兩個列表交集的例子:
Input :
lst1 = [15, 9, 10, 56, 23, 78, 5, 4, 9]
lst2 = [9, 4, 5, 36, 47, 26, 10, 45, 87]
Output :
[9, 10, 4, 5]
問題是,這需要在 numba 函數中計算,因此不能使用例如集合。 你有想法嗎? 我當前的代碼非常基礎。 我認為還有改進的余地。
@nb.njit
def intersection:
result = []
for element1 in lst1:
for element2 in lst2:
if element1 == element2:
result.append(element1)
....
由於 numba 以機器代碼編譯和運行您的代碼,因此您可能最適合這種簡單的操作。 我在下面運行了一些基准測試
@nb.njit
def loop_intersection(lst1, lst2):
result = []
for element1 in lst1:
for element2 in lst2:
if element1 == element2:
result.append(element1)
return result
@nb.njit
def set_intersect(lst1, lst2):
return set(lst1).intersection(set(lst2))
結果
loop_intersection
40.4 µs ± 1.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
set_intersect
42 µs ± 6.74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
我玩了一下,嘗試學習一些東西,意識到答案已經給出。 當我運行接受的答案時,我得到 [9, 10, 5, 4, 9] 的返回值。 我不清楚重復的 9 是否可以接受。 假設沒問題,我使用列表理解進行了一次試驗,看看它有什么不同。 我的結果:
from numba import jit
def createLists():
l1 = [15, 9, 10, 56, 23, 78, 5, 4, 9]
l2 = [9, 4, 5, 36, 47, 26, 10, 45, 87]
@jit
def listComp():
l1, l2 = createLists()
return [i for i in l1 for j in l2 if i == j]
%timeit listComp() 5.84 微秒 +/- 10.5 納秒
或者,如果您可以使用 Numpy,則此代碼會更快並刪除重復的“9”,並且使用 Numba 簽名會更快。
import numpy as np
from numba import jit, int64
@jit(int64[:](int64[:], int64[:]))
def JitListComp(l1, l2):
l3 = np.array([i for i in l1 for j in l2 if i == j])
return np.unique(l3) # and i not in crossSec]
@jit
def CreateList():
l1 = np.array([15, 9, 10, 56, 23, 78, 5, 4, 9])
l2 = np.array([9, 4, 5, 36, 47, 26, 10, 45, 87])
return JitListComp(l1, l2)
CreateList()
Out[39]: array([ 4, 5, 9, 10])
%timeit CreateList()
1.71 µs ± 10.4 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
您可以為此使用設置操作:
def intersection(lst1, lst2):
return list(set(lst1) & set(lst2))
然后簡單地調用函數intersection(lst1,lst2)
。 這將是最簡單的方法。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.