繁体   English   中英

使用numba优化Jaccard距离性能

[英]Optimize Jaccard distance performance with numba

我正在尝试使用Numba在python中实现尽可能快的jaccard距离版本

@nb.jit()
def nbjaccard(seq1, seq2):
    set1, set2 = set(seq1), set(seq2)
    return 1 - len(set1 & set2) / float(len(set1 | set2))

def jaccard(seq1, seq2):
    set1, set2 = set(seq1), set(seq2)
    return 1 - len(set1 & set2) / float(len(set1 | set2))


%%timeit
nbjaccard("compare this string","compare a different string")

--12.4毫秒

%%timeit 
jaccard("compare this string","compare a different string")

--3.87毫秒

为什么numba版本需要更长时间? 有什么方法可以获得加速?

在我看来,允许对象模式numba函数(或者如果numba意识到整个函数使用python对象没有警告)是一个设计错误 - 因为这些通常比纯python函数慢一点。

Numba非常强大(类型调度,你可以编写没有类型声明的python代码 - 与C扩展或Cython相比 - 真的很棒)但只有当它支持操作时:

这意味着“nopython”模式不支持任何未列在其中的操作。 如果numba必须回到“对象模式”,那么要小心:

对象模式

Numba编译模式,生成代码,将所有值作为Python对象处理,并使用Python C API对这些对象执行所有操作。 除非Numba编译器可以利用循环匹配,否则在对象模式下编译的代码通常不会比Python解释代码快。

而这恰恰是你的情况:你纯粹在对象模式下运作:

>>> nbjaccard.inspect_types()

[...]
# --- LINE 3 --- 
#   seq1 = arg(0, name=seq1)  :: pyobject
#   seq2 = arg(1, name=seq2)  :: pyobject
#   $0.1 = global(set: <class 'set'>)  :: pyobject
#   $0.3 = call $0.1(seq1)  :: pyobject
#   $0.4 = global(set: <class 'set'>)  :: pyobject
#   $0.6 = call $0.4(seq2)  :: pyobject
#   set1 = $0.3  :: pyobject
#   set2 = $0.6  :: pyobject

set1, set2 = set(seq1), set(seq2)

# --- LINE 4 --- 
#   $const0.7 = const(int, 1)  :: pyobject
#   $0.8 = global(len: <built-in function len>)  :: pyobject
#   $0.11 = set1 & set2  :: pyobject
#   $0.12 = call $0.8($0.11)  :: pyobject
#   $0.13 = global(float: <class 'float'>)  :: pyobject
#   $0.14 = global(len: <built-in function len>)  :: pyobject
#   $0.17 = set1 | set2  :: pyobject
#   $0.18 = call $0.14($0.17)  :: pyobject
#   $0.19 = call $0.13($0.18)  :: pyobject
#   $0.20 = $0.12 / $0.19  :: pyobject
#   $0.21 = $const0.7 - $0.20  :: pyobject
#   $0.22 = cast(value=$0.21)  :: pyobject
#   return $0.22

return 1 - len(set1 & set2) / float(len(set1 | set2))

正如您所看到的,每个操作都在Python对象上运行(如每行末尾的:: pyobject所示)。 那是因为numba不支持strset s。 所以绝对没有什么可以在这里更快。 除了你有一个想法如何使用numpy数组或同类列表(数值类型)解决这个问题。

在我的电脑上时间差异要大得多(使用numba 0.32.0),但个人时间要快得多 - 微秒10**-6秒)而不是毫秒10**-3秒):

%timeit nbjaccard("compare this string","compare a different string")
10000 loops, best of 3: 84.4 µs per loop

%timeit jaccard("compare this string","compare a different string")
100000 loops, best of 3: 15.9 µs per loop

请注意,默认情况下jit惰性的 ,因此第一次调用应该在执行时间之前完成 - 因为它包含编译代码的时间。


然而,你可以做一个优化:如果你知道两个集合的交集,你可以计算联合的长度(正如@Paul Hankin在他现在删除的答案中提到的):

len(union) = len(set1) + len(set2) - len(intersection)

这将导致以下(纯python)代码:

def jaccard2(seq1, seq2):
    set1, set2 = set(seq1), set(seq2)
    num_intersection = len(set1 & set2)
    return 1 - num_intersection / float(len(set1) + len(set2) - num_intersection)

%timeit jaccard2("compare this string","compare a different string")
100000 loops, best of 3: 13.7 µs per loop

不是更快 - 但更好。


如果你使用 ,还有一些改进

%load_ext cython

%%cython
def cyjaccard(seq1, seq2):
    cdef set set1 = set(seq1)
    cdef set set2 = set()

    cdef Py_ssize_t length_intersect = 0

    for char in seq2:
        if char not in set2:
            if char in set1:
                length_intersect += 1
            set2.add(char)

    return 1 - (length_intersect / float(len(set1) + len(set2) - length_intersect))

%timeit cyjaccard("compare this string","compare a different string")
100000 loops, best of 3: 7.97 µs per loop

这里的主要优点是只需一次迭代就可以创建set2并计算交集中的元素数量(根本不需要创建交集)!

当我计算这两个函数时, nbjaccard需要大约4.7微秒(在预热jit之后)和普通python使用Numba 0.32.0需要大约3.2微秒。 也就是说,我不认为numba会在这种情况下为你提供任何加速,因为目前在nopython模式下基本上没有字符串支持。 这意味着你要经历python对象层,这通常与没有jit运行没什么不同,除非numba可以做一些智能循环提升(即使用纯内在函数而不是python函数编译子块)。 除了在numba情况下检查输入的类型之外,你可能只需要支付一些小的开销。

我认为最重要的是你试图将numba用于目前尚未涵盖的用例。 Numba真正擅长的是处理numpy数组和数值标量值或可以推送到GPU的问题的操作。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM