[英]Casting Cython fused types to C++ pointers
這是關於將Cython融合類型轉換為C ++類型的一般問題,我將通過一個最小的示例進行介紹。 考慮一下表面的C ++函數模板:
template <typename T>
void scale_impl(const T * x, T * y, const T a, const size_t N) {
for (size_t n = 0; n < N; ++n) {
y[n] = a*x[n];
}
}
我希望能夠在任何類型和形狀的任何numpy ndarray
上調用此函數。 首先使用Cython聲明函數模板:
cdef extern:
void scale_impl[T](const T * x, T * y, const T a, const size_t N)
然后聲明我們要操作的有效標量類型:
ctypedef fused Scalar:
float
double
...
最后實現實際的Cython填充:
def scale(ndarray[Scalar] x, Scalar a):
"""Scale an array x by the value a"""
cdef ndarray[Scalar] y = np.empty_like(x)
scale_impl(<Scalar *>x.data, <Scalar *>y.data, a, x.size)
return y
這不起作用有兩個原因:
x
只能是一維的,不能是任意(或至少很多)維的 <Scalar *>
會引發錯誤,因為Scalar
實際上是一個Python對象 一個人顯然可以明確地推斷出這些專業:
if Scalar is float:
scale_impl(<float *>x.data, <float *>y.data, a, x.size)
if Scalar is double:
scale_impl(<double *>x.data, <double *>y.data, a, x.size)
if Scalar is ...
但是,這導致我不得不為娛樂多個融合類型的函數手寫一些代碼路徑,並產生了引入這種情況(我認為)是為了避免這種情況。
有什么方法可以將任意維數數組(在合理范圍內)傳遞給Cython函數,並推導出標量數據的指針類型? 或者,實現這種功能的最合理的折衷是什么?
(另請參見使用Cython包裝c ++模板以接受任何numpy數組中給出的答案,這是一個非常相似的問題。)
使用格式&x[0]
而不是嘗試x.data
解決了選擇正確的模板專業化的問題。 二維數組的問題更加復雜,因為不能保證數組是連續的或有序的。
我將創建一個在1D數組上執行實際工作的函數,並包裝一個根據需要變平的簡單函數:
def _scale_impl(Scalar[::1] x, Scalar a):
# the "::1" syntax ensures the array is actually continuous
cdef np.ndarray[Scalar,ndim=1] y = np.empty_like(x)
cdef size_t N = x.shape[0] # this seems to be necessary to avoid throwing off Cython's template deduction
scale_impl(&x[0],&y[0],a,N)
return y
def scale(x, a):
"""Scale an array x by the value a"""
y = _scale_impl(np.ravel(x),a)
return y.reshape(x.shape) # reshape needs to be kept out of Cython
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.