
Intersection of two lists in numba

I would like to know the fastest way to compute the intersection of two lists within a numba function. Just for clarification, here is an example of the intersection of two lists:

Input : 
lst1 = [15, 9, 10, 56, 23, 78, 5, 4, 9]
lst2 = [9, 4, 5, 36, 47, 26, 10, 45, 87]
Output :
[9, 10, 4, 5]

The problem is that this needs to be computed within the numba function, and therefore sets, for example, cannot be used. Do you have an idea? My current code is very basic, and I assume there is room for improvement.

import numba as nb

@nb.njit
def intersection(lst1, lst2):
    result = []
    for element1 in lst1:
        for element2 in lst2:
            if element1 == element2:
                result.append(element1)
    # ....
    return result
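For reference, a common way to avoid both sets and the O(n·m) nested loop is to sort both inputs and walk them with two pointers (a merge-style intersection, O(n log n + m log m)). This is a swapped-in technique, not the asker's code. The sketch assumes NumPy integer arrays as input, since arrays are the input type Numba handles most naturally; the name `sorted_intersection` is hypothetical, and it returns sorted unique values. If Numba is not installed, the fallback decorator simply leaves it as plain Python so it still runs.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python when Numba is absent
    def njit(func):
        return func


@njit
def sorted_intersection(a, b):
    # Sort copies of both inputs so equal values line up for a linear merge.
    a = np.sort(a)
    b = np.sort(b)
    # The intersection can never be longer than the shorter input.
    out = np.empty(min(a.size, b.size), dtype=a.dtype)
    i = 0
    j = 0
    k = 0
    while i < a.size and j < b.size:
        if a[i] < b[j]:
            i += 1
        elif a[i] > b[j]:
            j += 1
        else:
            v = a[i]
            out[k] = v
            k += 1
            # Skip every copy of v on both sides so each value is emitted once.
            while i < a.size and a[i] == v:
                i += 1
            while j < b.size and b[j] == v:
                j += 1
    return out[:k]
```

With the question's data, `sorted_intersection(np.array(lst1), np.array(lst2))` should give `array([4, 5, 9, 10])`. Sorting costs more than a set lookup asymptotically per element, but it needs no hashing and only one output array, which tends to suit compiled loops.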

Since numba compiles your code to machine code and runs it, you are probably already close to the best you can get for such a simple operation. I ran some benchmarks:

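import numba as nb

# Nested-loop version: compares every pair, so values can be repeated in the result
@nb.njit
def loop_intersection(lst1, lst2):
    result = []
    for element1 in lst1:
        for element2 in lst2:
            if element1 == element2:
                result.append(element1)
    return result

# Set version: returns a set, not a list
@nb.njit
def set_intersect(lst1, lst2):
    return set(lst1).intersection(set(lst2))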

Results

loop_intersection
40.4 µs ± 1.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

set_intersect
42 µs ± 6.74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
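When timing Numba functions, the first call includes JIT compilation, so it should be excluded from the measurement. A minimal sketch of a helper (the name `bench` is hypothetical) built on the standard-library `timeit` module, reporting the best and mean time per call after one warm-up call:

```python
import timeit


def bench(func, *args, number=1000, repeat=7):
    """Return (best, mean) seconds per call, excluding the first (compiling) call."""
    func(*args)  # warm-up: for an @nb.njit function this is where compilation happens
    totals = timeit.repeat(lambda: func(*args), number=number, repeat=repeat)
    per_call = [t / number for t in totals]
    return min(per_call), sum(per_call) / len(per_call)
```

For example, `best, mean = bench(loop_intersection, lst1, lst2)`. Timing every candidate with the same inputs this way keeps the comparison fair.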

I played with this a bit to try to learn something, realizing that the answer has already been given. When I run the accepted answer, I get a return value of [9, 10, 5, 4, 9]. I wasn't clear whether the repeated 9 was acceptable or not. Assuming it's OK, I ran a trial using a list comprehension to see whether it made any difference. My results:

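from numba import jit

@jit  # jitted too, so listComp can call it from compiled code
def createLists():
    l1 = [15, 9, 10, 56, 23, 78, 5, 4, 9]
    l2 = [9, 4, 5, 36, 47, 26, 10, 45, 87]
    return l1, l2  # listComp unpacks these two lists

@jit
def listComp():
    l1, l2 = createLists()
    return [i for i in l1 for j in l2 if i == j]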

%timeit listComp()
5.84 µs ± 10.5 ns
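If the repeated 9 is not acceptable, the loop version can skip values it has already emitted and stop scanning `lst2` after the first match. A minimal sketch (`loop_intersection_unique` is a hypothetical name) that keeps the order of `lst1`; passing Python lists into Numba relies on its reflected-list support, which newer Numba versions flag with a deprecation warning. The import fallback is an assumption so the sketch also runs as plain Python.

```python
try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python when Numba is absent
    def njit(func):
        return func


@njit
def loop_intersection_unique(lst1, lst2):
    result = []
    for element1 in lst1:
        if element1 in result:  # value already emitted, skip the inner scan
            continue
        for element2 in lst2:
            if element1 == element2:
                result.append(element1)
                break  # one match is enough
    return result
```

With the question's lists this should return `[9, 10, 5, 4]`.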

Or, if you can use NumPy, this code is even faster: it removes the duplicate 9, and the explicit Numba signature makes it faster still.

import numpy as np
from numba import jit, int64

@jit(int64[:](int64[:], int64[:]))
def JitListComp(l1, l2):
    l3 = np.array([i for i in l1 for j in l2 if i == j])
    return np.unique(l3)  # sorted, with duplicates removed

@jit
def CreateList():
    l1 = np.array([15, 9, 10, 56, 23, 78, 5, 4, 9])
    l2 = np.array([9, 4, 5, 36, 47, 26, 10, 45, 87])
    return JitListComp(l1, l2)

CreateList()
Out[39]: array([ 4,  5,  9, 10])

%timeit CreateList()
1.71 µs ± 10.4 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
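Outside a Numba function, plain NumPy already provides `np.intersect1d`, which returns the same sorted, deduplicated result. This sketch assumes it is called from ordinary Python code rather than inside `@jit`, as a baseline to compare against:

```python
import numpy as np

l1 = np.array([15, 9, 10, 56, 23, 78, 5, 4, 9])
l2 = np.array([9, 4, 5, 36, 47, 26, 10, 45, 87])

# Sorted unique values that appear in both arrays
common = np.intersect1d(l1, l2)
```

Here `common` should be `array([ 4,  5,  9, 10])`, matching the `CreateList()` output.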

You can use set operations for this:

def intersection(lst1, lst2): 
    return list(set(lst1) & set(lst2))

Then simply call intersection(lst1, lst2). This is the easiest way.
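`list(set(lst1) & set(lst2))` returns the common elements in arbitrary order, because set iteration order is not guaranteed. If order matters, a minimal plain-Python sketch (the name `intersection_ordered` is hypothetical, and it is not written for Numba) keeps the first-seen order of `lst1`:

```python
def intersection_ordered(lst1, lst2):
    lookup = set(lst2)  # hash set for average O(1) membership tests
    # dict.fromkeys drops duplicates from lst1 while keeping first-seen order
    return [x for x in dict.fromkeys(lst1) if x in lookup]
```

With the question's lists this returns `[9, 10, 5, 4]`.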

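Disclaimer: the technical posts on this site are licensed under CC BY-SA 4.0. If you republish them, please credit this site's URL or the original address. For any questions, contact yoyou2525@163.com.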

© 2020-2024 STACKOOM.COM