[英]cython cdef class c method: how to call it from another cython cdef class without python overhead?
我正在尝试在cython中实现通用排序算法。 因此,我创建了以下模块,该模块在类sorter_t
内部实现了Heapsort算法:
# file general_sort_c.pyx
from libc.stdint cimport int32_t
cdef bint bint_true = 1
cdef bint bint_false = 0
cdef class sorter_t:
cdef object sortable_object
def __init__(self,sortable_object):
self.sortable_object = sortable_object
cpdef sort_c(self):
"""
https://en.wikipedia.org/wiki/Heapsort
"""
cdef int32_t end
cdef int32_t count = self.sortable_object.num_elements_int32
self.heapify_c(count)
end = count-1
while end > 0:
self.sortable_object.swap_c(0,end)
end = end - 1
self.siftDown_c(0,end)
cdef heapify_c(self,int32_t count):
cdef int32_t start = (count - 2)/2
while start >= 0:
self.siftDown_c(start, count-1)
start -= 1
cdef siftDown_c(self,int32_t start, int32_t end):
cdef int32_t root = start
cdef int32_t swap
cdef int32_t child
while root * 2 + 1 <= end:
child = root * 2 + 1
swap = root
# if "swap" < "child" then ...
if self.sortable_object.lt_c(swap,child) == 1:
swap = child
if child+1 <= end and self.sortable_object.lt_c(swap,child+1) == 1:
swap = child + 1
if swap != root:
self.sortable_object.swap_c(root,swap)
root = swap
else:
return
当定义类型为sorter_t
的对象时,必须提供一个sortable_object
,它具有cdef函数lt_c
(用于比较一个元素是否小于另一个元素)和swap_c
(用于交换元素)的特定实现。
例如,以下代码将从列表中定义并创建sortable_object
,并使用该sortable_object
测试“ sorter_t”的实现。
import numpy
cimport numpy
from libc.stdint cimport int32_t
import general_sort_c
cdef class sortable_t:
cdef public int32_t num_elements_int32
cdef int32_t [:] mv_lista
def __init__(self,int32_t [:] mv_lista):
self.num_elements_int32 = mv_lista.shape[0]
self.mv_lista = mv_lista
cdef public bint lt_c(self, int32_t left, int32_t right):
if self.mv_lista[left] < self.mv_lista[right]:
return 1 # True
else:
return 0 # False
cdef public bint gt_c(self, int32_t left, int32_t right):
if self.mv_lista[left] > self.mv_lista[right]:
return 1 # True
else:
return 0 # False
cdef public swap_c(self, int32_t left, int32_t right):
cdef int32_t tmp
tmp = self.mv_lista[right]
self.mv_lista[right] = self.mv_lista[left]
self.mv_lista[left] = tmp
def probar():
lista = numpy.array([3,4,1,7],dtype=numpy.int32)
cdef int32_t [:] mv_lista = lista
cdef sortable = sortable_t(mv_lista)
cdef sorter = general_sort_c.sorter_t(sortable)
sorter.sort_increasing_c()
print list(lista)
编译两个.pyx
文件并在IPython控制台中运行以下命令后,出现以下错误:
In [1]: import test_general_sort_c as tgs
In [2]: tgs.probar()
...
general_sort_c.sorter_t.siftDown_increasing_c (general_sort_c.c:1452)()
132
133 #if mv_tnet_time[swap] < mv_tnet_time[child]:
--> 134 if self.sortable_object.lt_c(swap,child) == bint_true:
135 swap = child
136
AttributeError: 'test_general_sort_c.sortable_t' object has no attribute 'lt_c'
因此,问题在于从模块general_sort_c.pyx
的代码lt_c
不到函数lt_c
的实现。 如果我使用cpdef
而不是cdef
定义函数lt_c
,它将起作用,但是您将有很多Python开销。 如何以cdef
(“纯C”)方式调用此函数?
不幸的是,我不确定如何使它与融合类型一起使用,但是其余的很简单:
test_general_sort_c.pyx
需要一个免费的test_general_sort_c.pxd
:
from libc.stdint cimport int32_t
cdef class sortable_t:
cdef public int32_t num_elements_int32
cdef int32_t [:] mv_lista
cdef public bint lt_c(self, int32_t left, int32_t right)
cdef public bint gt_c(self, int32_t left, int32_t right)
cdef public swap_c(self, int32_t left, int32_t right)
general_sort_c.pyx
然后必须cimport
test_general_sort_c
并将其self.sortable_object
键入为test_general_sort_c.sortable_t
。
当然,如果您可以使用多种受支持的类型,那就更好了。 不过,目前还不确定您会怎么做。
此外,内置的True
和False
本身也可以正常工作。
如果您对Cython的信任度更高,您会意识到您可以编写
cdef public bint lt_c(self, int32_t left, int32_t right):
return self.mv_lista[left] < self.mv_lista[right]
cdef public bint gt_c(self, int32_t left, int32_t right):
return self.mv_lista[left] > self.mv_lista[right]
cdef public swap_c(self, int32_t left, int32_t right):
self.mv_lista[right], self.mv_lista[left] = self.mv_lista[left], self.mv_lista[right]
正好。 :)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.