繁体   English   中英

在其fit方法带有3个参数的自定义类上使用sklearn GridSearchCV

[英]Use sklearn GridSearchCV on custom class whose fit method takes 3 arguments

我正在从事一个涉及将某些算法实现为python类并测试其性能的项目。 我决定将它们编写为sklearn估计器,以便可以使用GridSearchCV进行验证。

但是,我的归纳矩阵补全算法之一不仅将Xy作为参数。 这对于GridSearchCV.fit是一个问题,因为似乎没有办法将Xy传递给估计量的fit方法。 源显示GridSearchCV.fit的以下参数:

def fit(self, X, y=None, groups=None, **fit_params):

当然,下游方法只期望这两个参数。 显然,修改GridSearchCV本地副本以适应我的需求绝非易事(或不建议这样做)。

作为参考,IMC基本上声明$ R \\大约XW ^ THY ^ T $。 因此,我的fit方法采用以下形式:

def fit(self, R, X, Y):

因此尝试以下操作失败,因为Y值永远不会传递给IMC.fit方法:

imc = IMC()
params = {...}
gs = GridSearchCV(imc, param_grid=params)
gs.fit(R, X, Y)

我已经通过修改IMC.fit方法(为此也必须将其插入到score方法中)创建了解决方法:

def fit(self, R, X, Y=None):
    if Y is None:
        split = np.where(np.all(X == 999, axis=0))[0][0]
        Y = X[:, split + 1:]
        X = X[:, :split]
    ...

这使我可以使用numpy.hstack来水平堆叠X和Y,并在它们之间插入所有999的列。 然后可以将该数组传递给GridSearchCV.fit ,如下所示:

data = np.hstack([X, np.ones((X.shape[0],1)) * 999, Y])
gs.fit(R, data)

这种方法有效,但感觉很棘手。 因此,我的问题是这样的:

是否存在使用GridSearchCV将两个以上的参数传递给fit方法的公认方法或最佳实践?

因此,从这个朋友( @Matthew Drury )那里得到一些启发之后,我构建了一个更为优雅的解决方案。

再次将问题归结为:

我有一个矩阵完成方法,该方法XY ,和R作为参数,并试图构建WH最小化R - XWHY在所有观察到的指数R fit方法的基本实现如下所示:

def fit(X, Y, R):
    W, H = do_minimization(X, Y, R)
    return W, H

这与标准的sklearn模型不太吻合,在标准的sklearn模型中,拟合需要一个X (输入模型的要素)和y (结果),看起来像这样:

def fit(X, y):
    W, H = do_minimization(X, y)
    return W, H

在您开始使用GridSearchCV或其他交叉验证方法之前,这并不是真正的问题,因为他们希望数据适合后一种格式。 因此,要将这两个概念结合起来,我需要一种将两个不同的矩阵XY打包到一个结构中的方法,而又不会失去二者的独立性。

在最初的5分钟中,我不得不致力于这一点,因此我想出了解决方案。 在矩阵R形状n, m中,行对应于X的记录,列对应于Y的记录,总共有b个条目。 如果我们为所有这些条目获取行索引和列索引,并在行上索引X ,在列上索引Y ,那么对于XY将得到等长的矩阵。 然后可以将它们水平堆叠,用一堆废话隔开,然后毫无问题地传递给交叉验证方法(我们只需要在原始类内部使用几个辅助方法,即可在拟合之前从堆栈中重建原始XY

这个问题的重点是找到优雅的解决方案,或者最好是现有的解决方案。 似乎并非如此,因此我将为从sklearn继承而构建的任何未来估计器/分类器提出以下模型,这些估计器/分类器不仅仅需要fit方法的单个特征矩阵。

创建一个数据处理程序

使用GridSearchCVfit方法会进行一轮检查,然后触发对估算器fit方法的所有调用。 其中之一确定传递的X数组是否可索引 该测试基本上检查X实现__getitem__iloc并且长度与y相同。 此长度检查要求X具有shape属性。 那时,分裂指数和拟合度可以按预期计算。 因此,我们需要一个实现__getitem__并具有shape属性的包装器。

class DataHandler(object):

    def __init(self, X, Y):
        self.X = X
        self.Y = Y
        self.shape = self.X.shape

    def __getitem__(self, x):
        return self.X[x], self.Y[x]

而已! 现在,我们可以通过修改fit方法来匹配sklearn风格,但在这种情况下,而不是X是一个数组,它要么是一个元组(由返回的结果__getitem__方法)或我们的一个实例DataHandler类。

现在,仅通过传递包含XY数组的DataHandler实例, GridSearchCV按预期工作。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM