
Scikit-learn: overriding a class method in a classifier

I am trying to override the predict_proba method of a classifier class. As far as I can tell, the easiest approach, where it applies, is to preprocess the input to the base class method or postprocess its output:

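from sklearn.ensemble import RandomForestClassifier

class RandomForestClassifierWrapper(RandomForestClassifier):

    def predict_proba(self, X):
        X = pre_process(X)  # pre_process / post_process are your own functions
        ret = super(RandomForestClassifierWrapper, self).predict_proba(X)
        return post_process(ret)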
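A runnable sketch of this pre/post-processing pattern. The pre_process and post_process functions below are hypothetical illustrations (NaN replacement on the way in, a probability floor with renormalisation on the way out), and the iris dataset is only toy data:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier


def pre_process(X):
    # Hypothetical pre-processing: cast to float and replace NaNs with 0.
    return np.nan_to_num(np.asarray(X, dtype=float))


def post_process(proba, floor=0.01):
    # Hypothetical post-processing: raise every class probability to at
    # least `floor`, then renormalise so each row still sums to 1.
    proba = np.clip(proba, floor, None)
    return proba / proba.sum(axis=1, keepdims=True)


class RandomForestClassifierWrapper(RandomForestClassifier):

    def predict_proba(self, X):
        X = pre_process(X)              # keep the returned value
        ret = super().predict_proba(X)  # unchanged base-class logic
        return post_process(ret)


X, y = load_iris(return_X_y=True)
clf = RandomForestClassifierWrapper(n_estimators=20, random_state=0).fit(X, y)
proba = clf.predict_proba(X[:5])
print(proba.shape)  # (5, 3)
```

Because the forest's predict method calls self.predict_proba internally, the post-processing in this wrapper also affects the labels returned by predict.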

However, what I actually want is to copy a variable that is created locally inside the base class method, processed there, and garbage-collected when the method returns. I intend to process the intermediate result stored in that variable. Is there a straightforward way to do this without messing with the base class internals?

There is no way to access a method's local variables from the outside. What you can do, since you have the source code of the base classifier, is override predict_proba by copying the code from the base class and handling the local variables however you want.
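A toy, library-free sketch of the idea. The classes Base and Wrapper and the attribute last_squared_ are hypothetical: the local variable squared disappears when Base.compute returns, while the subclass copies the method body and keeps the intermediate result on the instance.

```python
class Base:
    def compute(self, values):
        squared = [v * v for v in values]  # local: gone once compute() returns
        return sum(squared)


class Wrapper(Base):
    def compute(self, values):
        # Body copied from Base.compute, plus one extra line that stores
        # the intermediate result on the instance for later inspection.
        squared = [v * v for v in values]
        self.last_squared_ = squared       # hypothetical attribute name
        return sum(squared)


w = Wrapper()
total = w.compute([1, 2, 3])
print(total, w.last_squared_)  # 14 [1, 4, 9]
```

The trade-off is that the copied body must be kept in sync by hand if the base class implementation changes in a later library version.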

Try overriding:

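from joblib import Parallel, delayed
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils.validation import check_array, check_is_fitted

class RandomForestClassifierWrapper(RandomForestClassifier):

    def predict_proba(self, X):
        check_is_fitted(self, 'n_outputs_')

        # Check data (the forest's trees work on float32, as in the base class)
        X = check_array(X, dtype=np.float32, accept_sparse="csr")

        # Parallel loop: one probability array per tree. This swaps the
        # private helpers used by the base class for each tree's public
        # predict_proba.
        all_proba = Parallel(n_jobs=self.n_jobs, verbose=self.verbose,
                             backend="threading")(
            delayed(tree.predict_proba)(X) for tree in self.estimators_)

        # do something with all_proba

        # Reduce: average the per-tree probabilities (single-output case)
        return np.mean(all_proba, axis=0)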
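A sketch that checks what the per-tree intermediate looks like, using only the public estimators_ attribute of a fitted forest. The helper per_tree_proba is hypothetical and the check assumes a single-output classifier; it shows that averaging the per-tree probabilities reproduces the forest's own predict_proba, so the intermediate can also be rebuilt outside the class:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier


def per_tree_proba(forest, X):
    # Hypothetical helper: one (n_samples, n_classes) array per fitted tree,
    # i.e. the per-tree results the forest averages inside predict_proba.
    return [tree.predict_proba(X) for tree in forest.estimators_]


X, y = load_iris(return_X_y=True)
forest = RandomForestClassifier(n_estimators=25, random_state=0).fit(X, y)

all_proba = per_tree_proba(forest, X[:10])
# ... inspect or modify all_proba here ...
averaged = np.mean(all_proba, axis=0)

print(len(all_proba), averaged.shape)  # 25 (10, 3)
print(np.allclose(averaged, forest.predict_proba(X[:10])))
```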
