
Overriding predict() in statsmodels GLM to use in sklearn context

To use the Poisson GLM model of statsmodels in a scikit-learn context, I'm trying to set up my own model that inherits from GLM, BaseEstimator and RegressorMixin. My goal is to do things like cross-validation. This is my code:

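import statsmodels.api as sm
from sklearn.base import BaseEstimator, RegressorMixin

class GLM_sklearn(sm.GLM, BaseEstimator, RegressorMixin):
    def __init__(self, X, y, family=sm.families.Poisson()):
        # statsmodels expects (endog, exog), i.e. (y, X)
        super().__init__(y, X, family=family)

    def fit(self, **kwargs):
        # fit via sm.GLM.fit and keep the returned GLMResultsWrapper
        self.results_ = super().fit(**kwargs)

        self.coef_ = self.results_.params.values
        self.bse_ = self.results_.bse.values

        return self

    def predict(self, X, **kwargs):
        # this replaces sm.GLM.predict(params, exog, ...) for every caller,
        # including self.results_.predict itself, which leads to the TypeError
        return self.results_.predict(X)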

The fit method works fine, but I have a problem overriding predict(). To predict, I need the predict method of the results instance (GLMResultsWrapper), so I want to override the GLM.predict method (which does something different). With the code above I get the expected error, because GLMResults.predict itself calls back into self.model.predict, which is now my override with a different signature:

predict_results = self.model.predict(self.params, exog, *args, **kwargs)
TypeError: predict() takes 2 positional arguments but 3 were given
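To see why this happens, here is a minimal, self-contained sketch. Model and Results are hypothetical toy classes that only mimic the relevant statsmodels pattern (a results object whose predict() calls back into self.model.predict(self.params, exog), as the traceback line shows); they are not statsmodels' actual code.

```python
class Model:
    """Hypothetical stand-in for a statsmodels model class."""

    def predict(self, params, exog=None):
        # Model-level predict takes the fitted parameters explicitly.
        return [sum(p * x for p, x in zip(params, row)) for row in exog]

    def fit(self):
        return Results(self, params=[1.0, 2.0])


class Results:
    """Hypothetical stand-in for a statsmodels results class."""

    def __init__(self, model, params):
        self.model = model
        self.params = params

    def predict(self, exog):
        # Delegates back to the model, passing params first.
        return self.model.predict(self.params, exog)


class Subclass(Model):
    def fit(self):
        self.results_ = super().fit()
        return self

    # Different signature: this shadows Model.predict for *every* caller,
    # including Results.predict above (self.model is a Subclass instance).
    def predict(self, X):
        return self.results_.predict(X)


error_message = None
try:
    Subclass().fit().predict([[1.0, 1.0]])
except TypeError as exc:
    error_message = str(exc)
print(error_message)
```

Because `self.model` inside the results object is the subclass instance, Python's attribute lookup finds the override first, so the call with `(params, exog)` reaches `predict(self, X)` with one argument too many.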

Is there a way to override the predict method "completely"?

Instead of inheriting from all three, which can cause problems such as one parent class overwriting another's members (here, your predict replaces sm.GLM.predict), you may want GLM_sklearn to own an instance of sm.GLM and only inherit from BaseEstimator (plus RegressorMixin, which is a mixin meant to be inherited and mainly adds a score() method). Then you can implement fit and predict however you want without having to worry about the members of the parent classes.

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
Guangdong ICP Filing No. 18138465  © 2020-2024 STACKOOM.COM