簡體   English   中英

如何正確地將numpy邏輯函數傳遞給Cython?

[英]How to pass numpy logic functions to Cython correctly?

我應該將哪些聲明與邏輯函數/索引操作合並在一起,以便Cython能夠輕松完成任務?

我有兩個大小相等的numpy數組形式的大柵格。 第一個數組包含植被索引值,第二個數組包含字段ID。 目標是按田地平均植被指數值。 這兩個數組都有討厭的nodata值(-9999),我想忽略它們。

目前,該函數需要60秒鍾才能執行,通常我不會介意那么多,但我將處理數百個圖像。 甚至30秒的改善也是很重要的。 因此,我一直在探索Cython,以幫助加快運行速度。 我一直在使用Cython numpy教程作為指南。

示例數據

test_cy.pyx代碼:

import numpy as np
cimport numpy as np
cimport cython
@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function 

cpdef test():
  cdef np.ndarray[np.int16_t, ndim=2] ndvi_array = np.load("Z:cython_test/data/ndvi.npy")

  cdef np.ndarray[np.int16_t, ndim=2] field_array = np.load("Z:cython_test/data/field_array.npy")

  cdef np.ndarray[np.int16_t, ndim=1] unique_field = np.unique(field_array)
  unique_field = unique_field[unique_field != -9999]

  cdef int field_id
  cdef np.ndarray[np.int16_t, ndim=1] f_ndvi_values
  cdef double f_avg

  for field_id in unique_field :
      f_ndvi_values = ndvi_array[np.logical_and(field_array == field_id, ndvi_array != -9999)]
      f_avg = np.mean(f_ndvi_values)

Setup.py代碼:

try:
    from setuptools import setup
    from setuptools import Extension
except ImportError:
    from distutils.core import setup
    from distutils.extension import Extension

from Cython.Build import cythonize
import numpy

setup(ext_modules = cythonize('test_cy.pyx'),
      include_dirs=[numpy.get_include()])

經過研究和運行:

cython -a test_cy.pyx

看來索引操作ndvi_array[np.logical_and(field_array == field_id, ndvi_array != -9999)]是瓶頸,仍然依賴於Python。 我懷疑我在這里缺少一些重要的聲明。 包括ndim並沒有任何效果。

我對numpy也相當陌生,所以我可能缺少明顯的東西。

您的問題對我來說似乎可以解決,因此Cython可能不是最好的方法。 (當出現不可避免的細粒度循環時,Cython會發光。)由於int16int16 ,因此可能的標簽范圍非常有限,因此使用np.bincount應該相當有效。 嘗試類似的操作(這是假設所有有效值都> = 0,如果不是這種情況,則不必將其移位-或(廉價)視圖轉換為uint16 (因為我們沒有對應該安全)-在使用bincount之前):

mask = (ndvi_array != -9999) & (field_array != -9999)
nd = ndvi_array[mask]
fi = field_array[mask]
counts = np.bincount(fi, minlength=2**15)
sums = np.bincount(fi, nd, minlength=2**15)
valid = counts != 0
avgs = sums[valid] / counts[valid]

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM