[英]Optimize Jaccard distance performance with numba
我正在尝试使用Numba在python中实现尽可能快的jaccard距离版本
@nb.jit()
def nbjaccard(seq1, seq2):
set1, set2 = set(seq1), set(seq2)
return 1 - len(set1 & set2) / float(len(set1 | set2))
def jaccard(seq1, seq2):
set1, set2 = set(seq1), set(seq2)
return 1 - len(set1 & set2) / float(len(set1 | set2))
%%timeit
nbjaccard("compare this string","compare a different string")
--12.4毫秒
%%timeit
jaccard("compare this string","compare a different string")
--3.87毫秒
为什么numba版本需要更长时间? 有什么方法可以获得加速?
在我看来,允许纯对象模式numba函数(或者如果numba意识到整个函数使用python对象没有警告)是一个设计错误 - 因为这些通常比纯python函数慢一点。
Numba非常强大(类型调度,你可以编写没有类型声明的python代码 - 与C扩展或Cython相比 - 真的很棒)但只有当它支持操作时:
这意味着“nopython”模式不支持任何未列在其中的操作。 如果numba必须回到“对象模式”,那么要小心:
对象模式
Numba编译模式,生成代码,将所有值作为Python对象处理,并使用Python C API对这些对象执行所有操作。 除非Numba编译器可以利用循环匹配,否则在对象模式下编译的代码通常不会比Python解释代码快。
而这恰恰是你的情况:你纯粹在对象模式下运作:
>>> nbjaccard.inspect_types()
[...]
# --- LINE 3 ---
# seq1 = arg(0, name=seq1) :: pyobject
# seq2 = arg(1, name=seq2) :: pyobject
# $0.1 = global(set: <class 'set'>) :: pyobject
# $0.3 = call $0.1(seq1) :: pyobject
# $0.4 = global(set: <class 'set'>) :: pyobject
# $0.6 = call $0.4(seq2) :: pyobject
# set1 = $0.3 :: pyobject
# set2 = $0.6 :: pyobject
set1, set2 = set(seq1), set(seq2)
# --- LINE 4 ---
# $const0.7 = const(int, 1) :: pyobject
# $0.8 = global(len: <built-in function len>) :: pyobject
# $0.11 = set1 & set2 :: pyobject
# $0.12 = call $0.8($0.11) :: pyobject
# $0.13 = global(float: <class 'float'>) :: pyobject
# $0.14 = global(len: <built-in function len>) :: pyobject
# $0.17 = set1 | set2 :: pyobject
# $0.18 = call $0.14($0.17) :: pyobject
# $0.19 = call $0.13($0.18) :: pyobject
# $0.20 = $0.12 / $0.19 :: pyobject
# $0.21 = $const0.7 - $0.20 :: pyobject
# $0.22 = cast(value=$0.21) :: pyobject
# return $0.22
return 1 - len(set1 & set2) / float(len(set1 | set2))
正如您所看到的,每个操作都在Python对象上运行(如每行末尾的:: pyobject
所示)。 那是因为numba
不支持str
和set
s。 所以绝对没有什么可以在这里更快。 除了你有一个想法如何使用numpy数组或同类列表(数值类型)解决这个问题。
在我的电脑上时间差异要大得多(使用numba 0.32.0),但个人时间要快得多 - 微秒 ( 10**-6
秒)而不是毫秒 ( 10**-3
秒):
%timeit nbjaccard("compare this string","compare a different string")
10000 loops, best of 3: 84.4 µs per loop
%timeit jaccard("compare this string","compare a different string")
100000 loops, best of 3: 15.9 µs per loop
请注意,默认情况下jit
是惰性的 ,因此第一次调用应该在执行时间之前完成 - 因为它包含编译代码的时间。
然而,你可以做一个优化:如果你知道两个集合的交集,你可以计算联合的长度(正如@Paul Hankin在他现在删除的答案中提到的):
len(union) = len(set1) + len(set2) - len(intersection)
这将导致以下(纯python)代码:
def jaccard2(seq1, seq2):
set1, set2 = set(seq1), set(seq2)
num_intersection = len(set1 & set2)
return 1 - num_intersection / float(len(set1) + len(set2) - num_intersection)
%timeit jaccard2("compare this string","compare a different string")
100000 loops, best of 3: 13.7 µs per loop
不是更快 - 但更好。
%load_ext cython
%%cython
def cyjaccard(seq1, seq2):
cdef set set1 = set(seq1)
cdef set set2 = set()
cdef Py_ssize_t length_intersect = 0
for char in seq2:
if char not in set2:
if char in set1:
length_intersect += 1
set2.add(char)
return 1 - (length_intersect / float(len(set1) + len(set2) - length_intersect))
%timeit cyjaccard("compare this string","compare a different string")
100000 loops, best of 3: 7.97 µs per loop
这里的主要优点是只需一次迭代就可以创建set2
并计算交集中的元素数量(根本不需要创建交集)!
当我计算这两个函数时, nbjaccard
需要大约4.7微秒(在预热jit之后)和普通python使用Numba 0.32.0需要大约3.2微秒。 也就是说,我不认为numba会在这种情况下为你提供任何加速,因为目前在nopython
模式下基本上没有字符串支持。 这意味着你要经历python对象层,这通常与没有jit运行没什么不同,除非numba可以做一些智能循环提升(即使用纯内在函数而不是python函数编译子块)。 除了在numba情况下检查输入的类型之外,你可能只需要支付一些小的开销。
我认为最重要的是你试图将numba用于目前尚未涵盖的用例。 Numba真正擅长的是处理numpy数组和数值标量值或可以推送到GPU的问题的操作。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.