繁体   English   中英

覆盖要在sklearn上下文中使用的statsmodels GLM中的predict()

[英]Overriding predict() in statsmodels GLM to use in sklearn context

为了在sklearn的上下文中使用statsmodels的Poisson GLM模型,我试图建立一个自己的模型,该模型继承自GLM BaseEstimator an RegressorMixin。 我的目标是做类似交叉验证的工作。 这是我的代码:

import statsmodels.api as sm
from sklearn.base import BaseEstimator, RegressorMixin

class GLM_sklearn(sm.GLM, BaseEstimator, RegressorMixin):
    def __init__(self, X, y, family=sm.families.Poisson()):
        super().__init__(y, X, family=family)

    def fit(self, **kwargs):
        self.results_ = super().fit()

        self.coef_ = self.results_.params.values
        self.bse_ = self.results_.bse.values

        return self

    def predict(self, X, **kwargs):
        return self.results_.predict(X)

适合的方法工作正常,但我有一个覆盖predict()的问题。 要进行预测,我需要结果实例的预测方法(GLMResultsWrapper)。 因此,我想重写GLM.predict方法(具有另一个功能)。 如代码中所试,我得到了预期的错误:

预测结果= self.model.predict(self.params,exog,* args,** kwargs)TypeError:预测()接受2个位置参数,但给出了3个

是否有可能“完全”覆盖预测方法?

您可能希望GLM_sklearn拥有sm.GLM和RegressorMixin的实例,并且仅从BaseEstimator继承,而不是从所有这三个类中继承,这可能会给一个父类重写另一个成员带来麻烦。 然后,您可以实现拟合并根据需要进行预测,而不必担心父类的成员。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM