I'm trying to extend the Splitter class in sklearn, which works with sklearn's decision tree classes. More specifically, I want to add a feature_weights variable to the new class, which will affect the choice of the best split point by scaling the purity calculations in proportion to the feature weights.
The new class is almost an exact copy of sklearn's BestSplitter class ( https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_splitter.pyx ) with only minor changes. Here's what I have so far:
cdef class WeightedBestSplitter(WeightedBaseDenseSplitter):
    cdef object feature_weights  # new variable - 1D array of feature weights

    def __reduce__(self):
        # same as sklearn BestSplitter (basically)

    # NEW METHOD
    def set_weights(self, object feature_weights):
        feature_weights = np.asfortranarray(feature_weights, dtype=DTYPE)
        self.feature_weights = feature_weights

    cdef int node_split(self, double impurity, SplitRecord* split,
                        SIZE_t* n_constant_features) nogil except -1:
        # .... same as sklearn BestSplitter ....
        current_proxy_improvement = self.criterion.proxy_impurity_improvement()
        current_proxy_improvement *= self.feature_weights[<int>(current.feature)]  # new line
        # .... same as sklearn BestSplitter ....
A couple of notes about the above: I'm using the object variable type and np.asfortranarray because that is how the variable X is defined and set in other places, and X is indexed the same way I'm trying to index feature_weights. Also, current.feature has a variable type of SIZE_t per the _splitter.pxd file ( https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_splitter.pxd ).
The issue seems to be caused by the variable type of self.feature_weights. The above code throws multiple errors, but even referencing something like self.feature_weights[0] and assigning it to another variable throws the error:
Indexing Python object not allowed without gil
I'm wondering what I need to do to be able to index self.feature_weights and use the scalar value as a multiplier.
You definitely can't index a generic Python object without the GIL (as you're trying to do). You can index typed memoryviews without the GIL.
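Define feature_weights as
cdef double[:] feature_weights
The element type of the memoryview has to match the dtype of the array assigned to it. sklearn's tree module defines DTYPE as float32, so set_weights would most likely need to convert with dtype=np.float64 instead, or the attribute could be declared as DTYPE_t[:].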
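For illustration only, here is a pure-Python sketch (not Cython, and not sklearn's code) of what the weighted comparison amounts to once the weights live in a flat buffer of C doubles, which is the kind of object a double[:] attribute can wrap. The helper names prepare_weights and best_weighted_split and all the numbers are made up for this example, and the list of improvements only stands in for the values that criterion.proxy_impurity_improvement() would return.

```python
from array import array


def prepare_weights(feature_weights):
    """Copy the weights into a 1-D contiguous buffer of C doubles (format 'd')."""
    return memoryview(array("d", [float(w) for w in feature_weights]))


def best_weighted_split(proxy_improvements, weights):
    """Return (feature, weighted improvement) for the best candidate.

    proxy_improvements[f] is a stand-in for the proxy improvement of the
    best threshold found for feature f (an assumption of this sketch).
    """
    best_feature, best_value = -1, float("-inf")
    for feature, improvement in enumerate(proxy_improvements):
        # Same idea as the "new line" in node_split: scale by the feature weight.
        weighted = improvement * weights[feature]
        if weighted > best_value:
            best_feature, best_value = feature, weighted
    return best_feature, best_value


improvements = [4.0, 6.0, 2.5]  # hypothetical values, for illustration only
weights = prepare_weights([1.0, 0.5, 2.0])
uniform = prepare_weights([1.0, 1.0, 1.0])

unweighted = best_weighted_split(improvements, uniform)
weighted = best_weighted_split(improvements, weights)
print(unweighted, weighted)
```

With uniform weights feature 1 wins. With the example weights the winner moves to feature 2, because its improvement is doubled while feature 1's is halved. In the Cython class the same multiplication runs inside the nogil block, which is why the weights need to be a typed memoryview rather than a generic object.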