![](/img/trans.png)
[英]Sklearn Pipeline - How to inherit get_params in custom Transformer (not Estimator)
[英]Write custom transformer in sklearn which returns .predict of estimator in .transform
我們有一個定制的變壓器
class EstimatorTransformer(base.BaseEstimator, base.TransformerMixin):
def __init__(self, estimator):
self.estimator = estimator
def fit(self, X, y):
self = self.estimator.fit(X,y)
return self
def transform(self, X):
return self.estimator.predict(X)
並且有一個斷言語句
city_trans = EstimatorTransformer(city_est)
city_trans.fit(features,target)
assert ([r[0] for r in city_trans.transform(data[:5])]
== city_est.predict(data[:5]))
哪里
city_est
是我們可以通過的估計量。 我正在使用city_est = city_est = Ridge(alpha = 1)
但我在self = self.estimator.fit(X,y)
遇到錯誤。 我在這里可能做錯了。 我知道fit()
返回self
。 我應該如何使這個斷言起作用?
您在這一行中分配錯誤:
self = self.estimator.fit(X,y)
在這里,self是當前的類(EstimatorTransformer),您正在嘗試為其分配其他類。
您可以這樣寫:
def fit(self, X, y):
self.estimator.fit(X,y)
return self
它會工作。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.