繁体   English   中英

数据在scikit-learn转换器中不是持久性的

[英]Data not persistent in scikit-learn transformers

我想将其他数据传递给scikit-learn中的转换器:

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.ensemble import RandomForestClassifier

from sklearn.pipeline import Pipeline
import numpy as np
from sklearn.model_selection import GridSearchCV

class myTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, my_np_array):
        self.data = my_np_array
        print self.data

    def transform(self, X):
        return X

    def fit(self, X, y=None):
        return self

data = np.random.rand(20,20)
data2 = np.random.rand(6,6)
y = np.array([1, 2, 3, 1, 2, 3, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 3, 3, 3, 3])

pipe = Pipeline(steps=[('myt', myTransformer(data2)), ('randforest', RandomForestClassifier())])
params = {"randforest__n_estimators": [100, 1000]}
estimators = GridSearchCV(pipe, param_grid=params, verbose=True)
estimators.fit(data, y)

但是,在scikit学习管道中使用时,它似乎消失了

我从init方法中的打印中获取了None 我如何解决它?

发生这种情况是因为sklearn以非常特定的方式处理估算器。 通常,它将为诸如网格搜索之类的事情创建该类的新实例,并将参数传递给构造函数。 发生这种情况是因为sklearn拥有自己的克隆操作( 在base.py中定义 ),该操作会接收您的estimator类,获取参数(由get_params返回)并将其传递给您的类的构造函数

klass = estimator.__class__
new_object_params = estimator.get_params(deep=False)
for name, param in six.iteritems(new_object_params):
    new_object_params[name] = clone(param, safe=False)
new_object = klass(**new_object_params) 

为了支持您的对象必须重写get_params(deep=False)方法,该方法应返回字典,并将其传递给构造函数

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
class myTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, my_np_array):
        self.data = my_np_array
        print self.data

    def transform(self, X):
        return X

    def fit(self, X, y=None):
        return self

    def get_params(self, deep=False):
        return {'my_np_array': self.data}

将按预期工作。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM