![](/img/trans.png)
[英]Sklearn Pipeline - How to inherit get_params in custom Transformer (not Estimator)
[英]Write custom transformer in sklearn which returns .predict of estimator in .transform
我们有一个定制的变压器
class EstimatorTransformer(base.BaseEstimator, base.TransformerMixin):
def __init__(self, estimator):
self.estimator = estimator
def fit(self, X, y):
self = self.estimator.fit(X,y)
return self
def transform(self, X):
return self.estimator.predict(X)
并且有一个断言语句
city_trans = EstimatorTransformer(city_est)
city_trans.fit(features,target)
assert ([r[0] for r in city_trans.transform(data[:5])]
== city_est.predict(data[:5]))
哪里
city_est
是我们可以通过的估计量。 我正在使用city_est = city_est = Ridge(alpha = 1)
但我在self = self.estimator.fit(X,y)
遇到错误。 我在这里可能做错了。 我知道fit()
返回self
。 我应该如何使这个断言起作用?
您在这一行中分配错误:
self = self.estimator.fit(X,y)
在这里,self是当前的类(EstimatorTransformer),您正在尝试为其分配其他类。
您可以这样写:
def fit(self, X, y):
self.estimator.fit(X,y)
return self
它会工作。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.