簡體   English   中英

numba 中兩個列表的交集

[英]Intersection of two lists in numba

我想知道在 numba 函數中計算兩個列表交集的最快方法。 只是為了澄清:兩個列表交集的例子:

Input : 
lst1 = [15, 9, 10, 56, 23, 78, 5, 4, 9]
lst2 = [9, 4, 5, 36, 47, 26, 10, 45, 87]
Output :
[9, 10, 4, 5]

問題是,這需要在 numba 函數中計算,因此不能使用例如集合。 你有想法嗎? 我當前的代碼非常基礎。 我認為還有改進的余地。

@nb.njit
def intersection:
   result = []
   for element1 in lst1:
      for element2 in lst2:
         if element1 == element2:
            result.append(element1)
   ....

由於 numba 以機器代碼編譯和運行您的代碼,因此您可能最適合這種簡單的操作。 我在下面運行了一些基准測試

@nb.njit
def loop_intersection(lst1, lst2):
    result = []
    for element1 in lst1:
        for element2 in lst2:
            if element1 == element2:
                result.append(element1)
    return result

@nb.njit
def set_intersect(lst1, lst2):
    return set(lst1).intersection(set(lst2))

結果

loop_intersection
40.4 µs ± 1.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

set_intersect
42 µs ± 6.74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

我玩了一下,嘗試學習一些東西,意識到答案已經給出。 當我運行接受的答案時,我得到 [9, 10, 5, 4, 9] 的返回值。 我不清楚重復的 9 是否可以接受。 假設沒問題,我使用列表理解進行了一次試驗,看看它有什么不同。 我的結果:

from numba import jit

def createLists():
    l1 = [15, 9, 10, 56, 23, 78, 5, 4, 9]
    l2 = [9, 4, 5, 36, 47, 26, 10, 45, 87]

@jit
def listComp():
    l1, l2 = createLists()
    return [i for i in l1 for j in l2 if i == j]

%timeit listComp() 5.84 微秒 +/- 10.5 納秒

或者,如果您可以使用 Numpy,則此代碼會更快並刪除重復的“9”,並且使用 Numba 簽名會更快。

import numpy as np
from numba import jit, int64

@jit(int64[:](int64[:], int64[:]))
def JitListComp(l1, l2):
    l3 = np.array([i for i in l1 for j in l2 if i == j])
    return np.unique(l3) # and i not in crossSec]

@jit
def CreateList():
    l1 = np.array([15, 9, 10, 56, 23, 78, 5, 4, 9])
    l2 = np.array([9, 4, 5, 36, 47, 26, 10, 45, 87])
    return JitListComp(l1, l2)

CreateList()
Out[39]: array([ 4,  5,  9, 10])

%timeit CreateList()
1.71 µs ± 10.4 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

您可以為此使用設置操作:

def intersection(lst1, lst2): 
    return list(set(lst1) & set(lst2))

然后簡單地調用函數intersection(lst1,lst2) 這將是最簡單的方法。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM