簡體   English   中英

覆蓋要在sklearn上下文中使用的statsmodels GLM中的predict()

[英]Overriding predict() in statsmodels GLM to use in sklearn context

為了在sklearn的上下文中使用statsmodels的Poisson GLM模型,我試圖建立一個自己的模型,該模型繼承自GLM BaseEstimator an RegressorMixin。 我的目標是做類似交叉驗證的工作。 這是我的代碼:

import statsmodels.api as sm
from sklearn.base import BaseEstimator, RegressorMixin

class GLM_sklearn(sm.GLM, BaseEstimator, RegressorMixin):
    def __init__(self, X, y, family=sm.families.Poisson()):
        super().__init__(y, X, family=family)

    def fit(self, **kwargs):
        self.results_ = super().fit()

        self.coef_ = self.results_.params.values
        self.bse_ = self.results_.bse.values

        return self

    def predict(self, X, **kwargs):
        return self.results_.predict(X)

適合的方法工作正常,但我有一個覆蓋predict()的問題。 要進行預測,我需要結果實例的預測方法(GLMResultsWrapper)。 因此,我想重寫GLM.predict方法(具有另一個功能)。 如代碼中所試,我得到了預期的錯誤:

預測結果= self.model.predict(self.params,exog,* args,** kwargs)TypeError:預測()接受2個位置參數,但給出了3個

是否有可能“完全”覆蓋預測方法?

您可能希望GLM_sklearn擁有sm.GLM和RegressorMixin的實例,並且僅從BaseEstimator繼承,而不是從所有這三個類中繼承,這可能會給一個父類重寫另一個成員帶來麻煩。 然后,您可以實現擬合並根據需要進行預測,而不必擔心父類的成員。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM