I am trying to override the predict_proba method of a classifier class. As far as I can tell, the easiest approach, where it applies, is to preprocess the input to the base class method or postprocess its output:

```python
class RandomForestClassifierWrapper(RandomForestClassifier):
    def predict_proba(self, X):
        X = pre_process(X)
        ret = super(RandomForestClassifierWrapper, self).predict_proba(X)
        return post_process(ret)
```
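A minimal, self-contained sketch of that wrapper pattern, using a hypothetical toy `BaseClassifier` instead of scikit-learn; `pre_process` and `post_process` are illustrative stand-ins too. It shows the limitation described next: the override only sees the method's input and return value, never the local `scores`.

```python
class BaseClassifier:
    """Hypothetical stand-in for a library classifier."""

    def predict_proba(self, X):
        # 'scores' is a local intermediate result; it is gone once we return.
        scores = [[abs(v) + 1.0 for v in row] for row in X]
        return [[s / sum(row) for s in row] for row in scores]


def pre_process(X):
    # Hypothetical preprocessing: clip negative inputs to zero.
    return [[max(v, 0.0) for v in row] for row in X]


def post_process(proba):
    # Hypothetical postprocessing: round probabilities to 3 decimals.
    return [[round(p, 3) for p in row] for row in proba]


class ClassifierWrapper(BaseClassifier):
    def predict_proba(self, X):
        X = pre_process(X)
        ret = super().predict_proba(X)
        # Only the returned value is available here, not 'scores'.
        return post_process(ret)


print(ClassifierWrapper().predict_proba([[-2.0, 1.0]]))  # [[0.333, 0.667]]
```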
However, what I want is to copy a variable that is created locally inside the base class method, processed there, and garbage-collected when the method returns. I intend to process the intermediate result stored in this variable. Is there a straightforward way to do this without messing with the base class internals?
There is no way to access a method's local variables from outside the method. Since you have the source code of the base classifier, what you can do is override the predict_proba method by copying the base class implementation into your subclass and handling the local variables however you want.
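A sketch of that copy-and-modify approach on the same kind of hypothetical toy class (not scikit-learn): the override repeats the base method's body and adds one line that keeps a copy of the local intermediate on the instance. The attribute name `last_scores_` is an illustrative assumption.

```python
class BaseClassifier:
    """Hypothetical stand-in for a library classifier."""

    def predict_proba(self, X):
        scores = [[abs(v) + 1.0 for v in row] for row in X]  # local intermediate
        return [[s / sum(row) for s in row] for row in scores]


class ClassifierWrapper(BaseClassifier):
    def predict_proba(self, X):
        # Body copied from BaseClassifier.predict_proba ...
        scores = [[abs(v) + 1.0 for v in row] for row in X]
        # ... plus one addition: keep a copy of the local intermediate.
        self.last_scores_ = [row[:] for row in scores]
        return [[s / sum(row) for s in row] for row in scores]


clf = ClassifierWrapper()
proba = clf.predict_proba([[-2.0, 1.0]])
print(clf.last_scores_)  # [[3.0, 2.0]]
```

The cost of this design is maintenance: the copied body no longer tracks the base class, so it has to be updated by hand whenever the library's implementation changes.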
Try overriding it along these lines:
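```python
from joblib import Parallel, delayed
from sklearn.ensemble import RandomForestClassifier
# Private scikit-learn helpers: their module paths change between releases
# (older versions used sklearn.ensemble.base and sklearn.externals.joblib).
from sklearn.ensemble._base import _partition_estimators
from sklearn.tree._tree import DTYPE
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted


class RandomForestClassifierWrapper(RandomForestClassifier):
    def predict_proba(self, X):
        check_is_fitted(self, 'n_outputs_')
        # Check data
        X = check_array(X, dtype=DTYPE, accept_sparse="csr")
        # Assign chunk of trees to jobs
        n_jobs, n_trees, starts = _partition_estimators(self.n_estimators,
                                                        self.n_jobs)
        # Parallel loop: one probability array per tree. The generator was cut
        # off in the original; calling each tree's predict_proba is an assumption.
        all_proba = Parallel(n_jobs=n_jobs, verbose=self.verbose,
                             backend="threading")(
            delayed(e.predict_proba)(X) for e in self.estimators_)
        # do something with all_proba
        # Average over trees, as the base class does (single-output case).
        return sum(all_proba) / len(all_proba)
```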
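If the local you are after is the per-tree probabilities (the `all_proba` list above), a sketch that avoids copying private scikit-learn code is to recompute them from the fitted trees in the public `estimators_` attribute. This assumes a single-output forest; `per_tree_proba` is a hypothetical helper name, and the iris dataset is only sample data.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier


def per_tree_proba(forest, X):
    """Return one (n_samples, n_classes) array per fitted tree."""
    # estimators_ is the public list of fitted DecisionTreeClassifier objects.
    return [tree.predict_proba(X) for tree in forest.estimators_]


X, y = load_iris(return_X_y=True)
forest = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)

all_proba = per_tree_proba(forest, X)
# For a single-output forest, the stock predict_proba is the mean over trees.
print(np.allclose(np.mean(all_proba, axis=0), forest.predict_proba(X)))  # True
```

The trade-off is that the trees are evaluated a second time if you also call predict_proba, but nothing here depends on private helpers whose import paths move between releases.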