Unable to load pickled custom estimator sklearn pipeline
I have an sklearn pipeline that uses a custom column transformer, a custom estimator, and several lambda functions.
Because pickle cannot serialize lambda functions, I am using dill.
Here is the custom estimator I have:
import pandas as pd
import statsmodels.api as sm
from sklearn.base import BaseEstimator
from sklearn.pipeline import Pipeline

class customOLS(BaseEstimator):
    def __init__(self, ols):
        self.estimator_ols = ols

    def fit(self, X, y):
        X = pd.DataFrame(X)
        y = pd.DataFrame(y)
        print('---- Training OLS')
        self.estimator_ols = self.estimator_ols(y, X).fit()
        #print('---- Training LR')
        #self.estimator_lr = self.estimator_lr.fit(X,y)
        return self

    def get_estimators(self):
        return self.estimator_ols #, self.estimator_lr

    def predict_ols(self, X):
        res = self.estimator_ols.predict(X)
        return res

pipeline2 = Pipeline(
    steps=[
        ('dropper', drop_cols),
        ('remover', feature_remover),
        ("preprocessor", preprocess_ppl),
        ("estimator", customOLS(sm.OLS))
    ]
)
This is how I serialize it (I have to use open(); otherwise I get an UnsupportedOperation read/write error):
with open('data/baseModel_LR.joblib', "wb") as f:
    dill.dump(pipeline2, f)
But when I try to load the pickled object:
with open('data/baseModel_LR.joblib', "rb") as f:
    model = dill.load(f)
model
I get this error related to the custom estimator:
AttributeError: 'customOLS' object has no attribute 'ols'
The problem lies in these two lines:
    def __init__(self, ols):
        self.estimator_ols = ols
Here's an excerpt from the sklearn documentation, which explains why this won't work:
All scikit-learn estimators have get_params and set_params functions. The get_params function takes no arguments and returns a dict of the __init__ parameters of the estimator, together with their values.
So, if you have a parameter named ols in your constructor, sklearn assumes that you have an attribute on your object called ols. When you call get_params() on your object (and repr() does call that), it extracts the name of each parameter from the constructor and then looks up an attribute with the same name. Here the attribute is called estimator_ols, not ols, so the lookup fails with the AttributeError above.
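The lookup described above can be imitated with the standard library alone. This is only a simplified sketch of the idea, not sklearn's actual implementation; ParamsMixin, Mismatched and Matched are hypothetical names made up for the illustration.

```python
import inspect


class ParamsMixin:
    """Simplified, hypothetical imitation of the get_params() lookup."""

    def get_params(self):
        # Read the parameter names from the __init__ signature ...
        signature = inspect.signature(type(self).__init__)
        names = [name for name in signature.parameters if name != "self"]
        # ... then fetch an attribute with the same name for each one.
        return {name: getattr(self, name) for name in names}


class Mismatched(ParamsMixin):
    def __init__(self, ols):
        self.estimator_ols = ols  # attribute name differs from the parameter name


class Matched(ParamsMixin):
    def __init__(self, estimator_ols):
        self.estimator_ols = estimator_ols  # attribute and parameter names agree


matched_params = Matched("model").get_params()

try:
    Mismatched("model").get_params()
    mismatch_error = None
except AttributeError as exc:
    # getattr(self, "ols") fails because only "estimator_ols" was set
    mismatch_error = str(exc)

print(matched_params)
print(mismatch_error)
```

The mismatched class fails in the same way as the estimator in the question: the parameter name is found in the signature, but no attribute of that name exists on the object.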
To fix it, change the constructor to this:
    def __init__(self, estimator_ols):
        self.estimator_ols = estimator_ols
When I do that, I am able to save and load the pipeline.
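A quick way to check the renamed constructor without building the whole pipeline is to call get_params() and repr() directly. This is a minimal sketch, assuming scikit-learn is installed; BrokenOLS and FixedOLS are hypothetical stand-ins for the estimator, and a placeholder string replaces sm.OLS so that statsmodels is not needed.

```python
from sklearn.base import BaseEstimator


class BrokenOLS(BaseEstimator):
    """Hypothetical stand-in: parameter 'ols' is stored under another name."""

    def __init__(self, ols=None):
        self.estimator_ols = ols


class FixedOLS(BaseEstimator):
    """Hypothetical stand-in: attribute name matches the parameter name."""

    def __init__(self, estimator_ols=None):
        self.estimator_ols = estimator_ols


# Placeholder value used instead of sm.OLS (an assumption for this sketch).
fixed = FixedOLS(estimator_ols="placeholder")
fixed_params = fixed.get_params()
fixed_repr = repr(fixed)  # repr() goes through get_params() as well

try:
    BrokenOLS(ols="placeholder").get_params()
    broken_error = None
except AttributeError as exc:
    broken_error = str(exc)

print(fixed_params)
print(fixed_repr)
print(broken_error)
```

If get_params() and repr() both succeed on the estimator, displaying the loaded pipeline should no longer trip over the missing attribute.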