[英]Use sklearn GridSearchCV on custom class whose fit method takes 3 arguments
我正在从事一个涉及将某些算法实现为python类并测试其性能的项目。 我决定将它们编写为sklearn估计器,以便可以使用GridSearchCV
进行验证。
但是,我的归纳矩阵补全算法之一不仅将X
和y
作为参数。 这对于GridSearchCV.fit
是一个问题,因为似乎没有办法将X
和y
传递给估计量的fit方法。 源显示GridSearchCV.fit
的以下参数:
def fit(self, X, y=None, groups=None, **fit_params):
当然,下游方法只期望这两个参数。 显然,修改GridSearchCV
本地副本以适应我的需求绝非易事(或不建议这样做)。
作为参考,IMC基本上声明$ R \\大约XW ^ THY ^ T $。 因此,我的fit方法采用以下形式:
def fit(self, R, X, Y):
因此尝试以下操作失败,因为Y值永远不会传递给IMC.fit
方法:
imc = IMC()
params = {...}
gs = GridSearchCV(imc, param_grid=params)
gs.fit(R, X, Y)
我已经通过修改IMC.fit
方法(为此也必须将其插入到score
方法中)创建了解决方法:
def fit(self, R, X, Y=None):
if Y is None:
split = np.where(np.all(X == 999, axis=0))[0][0]
Y = X[:, split + 1:]
X = X[:, :split]
...
这使我可以使用numpy.hstack
来水平堆叠X和Y,并在它们之间插入所有999
的列。 然后可以将该数组传递给GridSearchCV.fit
,如下所示:
data = np.hstack([X, np.ones((X.shape[0],1)) * 999, Y])
gs.fit(R, data)
这种方法有效,但感觉很棘手。 因此,我的问题是这样的:
GridSearchCV
将两个以上的参数传递给fit方法的公认方法或最佳实践? 因此,从这个朋友( @Matthew Drury )那里得到一些启发之后,我构建了一个更为优雅的解决方案。
再次将问题归结为:
我有一个矩阵完成方法,该方法X
, Y
,和R
作为参数,并试图构建W
和H
最小化R - XWHY
在所有观察到的指数R
fit
方法的基本实现如下所示:
def fit(X, Y, R):
W, H = do_minimization(X, Y, R)
return W, H
这与标准的sklearn模型不太吻合,在标准的sklearn模型中,拟合需要一个X
(输入模型的要素)和y
(结果),看起来像这样:
def fit(X, y):
W, H = do_minimization(X, y)
return W, H
在您开始使用GridSearchCV
或其他交叉验证方法之前,这并不是真正的问题,因为他们希望数据适合后一种格式。 因此,要将这两个概念结合起来,我需要一种将两个不同的矩阵X
和Y
打包到一个结构中的方法,而又不会失去二者的独立性。
在最初的5分钟中,我不得不致力于这一点,因此我想出了解决方案。 在矩阵R
形状n, m
中,行对应于X
的记录,列对应于Y
的记录,总共有b
个条目。 如果我们为所有这些条目获取行索引和列索引,并在行上索引X
,在列上索引Y
,那么对于X
和Y
将得到等长的矩阵。 然后可以将它们水平堆叠,用一堆废话隔开,然后毫无问题地传递给交叉验证方法(我们只需要在原始类内部使用几个辅助方法,即可在拟合之前从堆栈中重建原始X
和Y
这个问题的重点是找到优雅的解决方案,或者最好是现有的解决方案。 似乎并非如此,因此我将为从sklearn继承而构建的任何未来估计器/分类器提出以下模型,这些估计器/分类器不仅仅需要fit方法的单个特征矩阵。
使用GridSearchCV
, fit
方法会进行一轮检查,然后触发对估算器fit
方法的所有调用。 其中之一确定传递的X
数组是否可索引 。 该测试基本上检查X
实现__getitem__
或iloc
并且长度与y
相同。 此长度检查要求X
具有shape
属性。 那时,分裂指数和拟合度可以按预期计算。 因此,我们需要一个实现__getitem__
并具有shape
属性的包装器。
class DataHandler(object):
def __init(self, X, Y):
self.X = X
self.Y = Y
self.shape = self.X.shape
def __getitem__(self, x):
return self.X[x], self.Y[x]
而已! 现在,我们可以通过修改fit
方法来匹配sklearn风格,但在这种情况下,而不是X
是一个数组,它要么是一个元组(由返回的结果__getitem__
方法)或我们的一个实例DataHandler
类。
现在,仅通过传递包含X
和Y
数组的DataHandler
实例, GridSearchCV
按预期工作。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.