

Unable to load pickled custom estimator sklearn pipeline

I have a sklearn pipeline that uses a custom column transformer, a custom estimator and several lambda functions.

Because pickle cannot serialize lambda functions, I am using dill.
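The limitation that motivates dill can be seen with the standard library alone. This is a minimal sketch: `pickle` stores functions by reference (module plus qualified name), so an anonymous lambda cannot be looked up again and fails, while `operator.itemgetter(0)` pickles fine. The names `can_pickle`, `first_column` and `first_column_named` are hypothetical, and replacing a lambda with `itemgetter` is an alternative workaround, not what the question does. dill takes the other route and serializes the lambda itself.

```python
import operator
import pickle


def can_pickle(obj):
    """Return True if the standard pickle module can serialize obj."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


# Hypothetical lambda of the kind a pipeline step might use.
first_column = lambda row: row[0]
# A picklable stand-in with the same behaviour (hypothetical alternative).
first_column_named = operator.itemgetter(0)

print(can_pickle(first_column))        # False
print(can_pickle(first_column_named))  # True
```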

Here is my custom estimator:

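import pandas as pd
import statsmodels.api as sm
from sklearn.base import BaseEstimator
from sklearn.pipeline import Pipeline

class customOLS(BaseEstimator):
    # The constructor receives the model class (sm.OLS in the pipeline below)
    def __init__(self, ols):
        self.estimator_ols = ols

    def fit(self, X, y):
        X = pd.DataFrame(X)
        y = pd.DataFrame(y)
        print('---- Training OLS')
        # statsmodels' OLS takes (endog, exog), i.e. (y, X)
        self.estimator_ols = self.estimator_ols(y, X).fit()
        #print('---- Training LR')
        #self.estimator_lr = self.estimator_lr.fit(X,y)
        return self

    def get_estimators(self):
        return self.estimator_ols  #, self.estimator_lr

    def predict_ols(self, X):
        # Delegate to the fitted statsmodels results object
        res = self.estimator_ols.predict(X)
        return res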

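# drop_cols, feature_remover and preprocess_ppl are my own transformers
# (their definitions are not shown here)
pipeline2 = Pipeline(
    steps=[
        ('dropper', drop_cols),
        ('remover', feature_remover),
        ("preprocessor", preprocess_ppl),
        ("estimator", customOLS(sm.OLS)),
    ]
)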

This is how I serialize it (I have to use open(); otherwise I get an UnsupportedOperation error about read/write):

with open('data/baseModel_LR.joblib', "wb") as f:
    dill.dump(pipeline2, f)
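A minimal sketch of the file-handle side, using the standard `pickle` module as a stand-in for dill (whose `dill.dump(obj, file)` and `dill.load(file)` follow the same call pattern): write with mode `"wb"`, read back with mode `"rb"`. The question does not show the call that raised UnsupportedOperation, so the second helper only demonstrates one plausible cause, which is an assumption: loading from a handle that was opened for writing. The helper names and the temporary path are hypothetical.

```python
import io
import os
import pickle
import tempfile


def round_trip(obj, path):
    # pickle (and dill) write bytes, so use binary modes on both sides.
    with open(path, "wb") as f:
        pickle.dump(obj, f)
    with open(path, "rb") as f:
        return pickle.load(f)


def load_from_write_handle(path):
    # Assumed failure mode: a handle opened with "wb" is not readable.
    with open(path, "wb") as f:
        try:
            pickle.load(f)
        except io.UnsupportedOperation as exc:
            return f"UnsupportedOperation: {exc}"
    return "no error"


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    print(round_trip({"steps": ["dropper", "remover"]}, path))
    # {'steps': ['dropper', 'remover']}
    print(load_from_write_handle(path))
```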

But when I try to load the pickled object:

with open('data/baseModel_LR.joblib', "rb") as f:
    model = dill.load(f)
model

I get this error related to the custom estimator:

AttributeError: 'customOLS' object has no attribute 'ols'


The problem lies in these two lines:

    def __init__(self, ols):
        self.estimator_ols = ols

Here's an excerpt from the scikit-learn documentation that explains why this won't work:

All scikit-learn estimators have get_params and set_params functions. The get_params function takes no arguments and returns a dict of the estimator's __init__ parameters, together with their values.

(Source: scikit-learn documentation.)

So, if your constructor has a parameter named ols, sklearn assumes that your object also has an attribute called ols. When you call get_params() on your object (and repr() does call it), it reads the parameter names from the constructor's signature and then looks up the attribute with each of those names. Because there is no self.ols, that lookup raises the AttributeError.
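The mechanism can be reproduced without scikit-learn. The sketch below uses a hypothetical, simplified stand-in for `BaseEstimator.get_params()`, not sklearn's actual code: it reads the `__init__` parameter names with `inspect.signature` and then calls `getattr` for each one. That is enough to reproduce the error for the original constructor and to show that matching names avoid it. The class and helper names are illustrative only.

```python
import inspect


def simplified_get_params(est):
    """Hypothetical simplification of get_params(): collect the __init__
    parameter names, then read an attribute with each name."""
    init_params = inspect.signature(type(est).__init__).parameters
    names = [name for name in init_params if name != "self"]
    return {name: getattr(est, name) for name in names}


def describe(est):
    # Return the params dict, or the AttributeError message if a lookup fails.
    try:
        return simplified_get_params(est)
    except AttributeError as exc:
        return f"AttributeError: {exc}"


class BrokenOLS:
    def __init__(self, ols):
        self.estimator_ols = ols  # parameter "ols", attribute "estimator_ols"


class FixedOLS:
    def __init__(self, estimator_ols):
        self.estimator_ols = estimator_ols  # names match


print(describe(BrokenOLS("placeholder")))
# AttributeError: 'BrokenOLS' object has no attribute 'ols'
print(describe(FixedOLS("placeholder")))
# {'estimator_ols': 'placeholder'}
```

Because repr() goes through get_params(), the error in the question can surface when the notebook displays `model`, even though `dill.load` itself has already returned.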

To fix it, rename the constructor parameter so that it matches the attribute name:

    def __init__(self, estimator_ols):
        self.estimator_ols = estimator_ols

With that change, I am able to save and load the pipeline.
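To check the fix against the real API, the sketch below assumes scikit-learn is installed. statsmodels is left out, and a placeholder string stands in for `sm.OLS`; the class and helper names are hypothetical. It also calls `sklearn.base.clone`, which rebuilds an estimator from `get_params()`. Tools such as GridSearchCV clone estimators, so the naming convention matters beyond repr().

```python
from sklearn.base import BaseEstimator, clone


class BrokenOLS(BaseEstimator):
    def __init__(self, ols=None):
        self.estimator_ols = ols  # "ols" has no matching attribute


class FixedOLS(BaseEstimator):
    def __init__(self, estimator_ols=None):
        self.estimator_ols = estimator_ols  # parameter and attribute share a name


def params_or_error(est):
    """Return est.get_params(), or the AttributeError message if it fails."""
    try:
        return est.get_params()
    except AttributeError as exc:
        return f"AttributeError: {exc}"


print(params_or_error(BrokenOLS("placeholder")))  # an AttributeError mentioning 'ols'
print(params_or_error(FixedOLS("placeholder")))   # {'estimator_ols': 'placeholder'}

# clone() builds a new instance from get_params(), so it needs matching names.
cloned = clone(FixedOLS("placeholder"))
print(cloned.estimator_ols)  # placeholder
```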

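Disclaimer: technical posts on this site are published under the CC BY-SA 4.0 license. If you repost, please cite this site's URL or the original address. For any questions, contact yoyou2525@163.com.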

Guangdong ICP Registration No. 18138465  © 2020-2024 STACKOOM.COM