Cython - Indexing numpy array within nogil function

I'm trying to extend the Splitter class in sklearn, which works with sklearn's decision tree classes. More specifically, I want to add a feature_weights variable in the new class, which will affect the determination of the best split point by altering the purity calculations proportionally with the feature weights.

The new class is almost an exact copy of sklearn's BestSplitter class ( https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_splitter.pyx ) with only minor changes. Here's what I have so far:

cdef class WeightedBestSplitter(WeightedBaseDenseSplitter):

    cdef object feature_weights # new variable - 1D array of feature weights

    def __reduce__(self):
        # same as sklearn BestSplitter (basically)

    # NEW METHOD
    def set_weights(self, object feature_weights): 
        feature_weights = np.asfortranarray(feature_weights, dtype=DTYPE)
        self.feature_weights = feature_weights  

    cdef int node_split(self, double impurity, SplitRecord* split,
                        SIZE_t* n_constant_features) nogil except -1:

        # .... same as sklearn BestSplitter ....

        current_proxy_improvement = self.criterion.proxy_impurity_improvement()
        current_proxy_improvement *= self.feature_weights[<int>(current.feature)]  # new line

        # .... same as sklearn BestSplitter ....

A couple of notes about the above: I'm using the object variable type and np.asfortranarray because that is how the variable X is defined and set elsewhere, and X is indexed the same way I'm trying to index feature_weights. Also, current.feature has a variable type of SIZE_t per the _splitter.pxd file ( https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_splitter.pxd ).

The issue seems to be caused by the variable type of self.feature_weights. The above code throws multiple errors, but even reading something like self.feature_weights[0] and assigning it to another variable throws the error:

Indexing Python object not allowed without gil

I'm wondering what I need to do to be able to index self.feature_weights and use the scalar value as a multiplier.

You definitely can't index a generic Python object without the GIL, which is what the code is trying to do. You can, however, index typed memoryviews without the GIL.

Define feature_weights as

cdef double[:] feature_weights
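
To show how that declaration fits together with the setter and the nogil indexing, here is a minimal standalone sketch. It is not sklearn's code: the class name WeightedSplitterSketch and the methods scale and scale_py are made up for illustration, and the conversion to float64 is my assumption so that the array matches double[:].

```cython
import numpy as np

cdef class WeightedSplitterSketch:
    # Typed memoryview attribute: element access compiles to C-level
    # indexing, so it is allowed in nogil code.
    cdef double[:] feature_weights

    def set_weights(self, object feature_weights):
        # Runs with the GIL held. The array dtype must match the memoryview's
        # element type, so convert to float64 here (assumption: double[:]).
        self.feature_weights = np.ascontiguousarray(feature_weights,
                                                    dtype=np.float64)

    cdef double scale(self, double improvement, Py_ssize_t feature) nogil:
        # Hypothetical helper standing in for the new line in node_split.
        return improvement * self.feature_weights[feature]

    def scale_py(self, double improvement, Py_ssize_t feature):
        # Small Python-visible wrapper so the nogil path can be exercised.
        cdef double result
        with nogil:
            result = self.scale(improvement, feature)
        return result
```

One thing to watch: as far as I know, DTYPE in sklearn's tree module is float32, so an array created with dtype=DTYPE (as in the question's set_weights) would be rejected by a double[:] memoryview with a buffer dtype mismatch on assignment. Either convert to float64 as in the sketch, or declare the attribute with the matching element type (DTYPE_t[:], assuming that typedef is cimported from sklearn's tree module).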
