繁体   English   中英

Sklearn:交叉验证分组数据

[英]Sklearn: Cross validation for grouped data

我正在尝试对分组数据实施交叉验证方案。 我本来希望使用GroupKFold方法,但一直出现错误。 我究竟做错了什么? 代码(与我使用的代码略有不同-我有不同的数据,所以我有一个更大的n_splits,但其他所有内容都是相同的)

from sklearn import metrics
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import GroupKFold
from sklearn.grid_search import GridSearchCV
from xgboost import XGBRegressor
#generate data
x=np.array([0,1,2,3,4,5,6,7,8,9,10,11,12,13])
y= np.array([1,2,3,4,5,6,7,1,2,3,4,5,6,7])
group=np.array([1,0,1,1,2,2,2,1,1,1,2,0,0,2)]
#grid search
gkf = GroupKFold( n_splits=3).split(x,y,group)
subsample = np.arange(0.3,0.5,0.1)
param_grid = dict( subsample=subsample)
rgr_xgb = XGBRegressor(n_estimators=50)
grid_search = GridSearchCV(rgr_xgb, param_grid, cv=gkf, n_jobs=-1)
result = grid_search.fit(x, y)

错误:

Traceback (most recent call last):

File "<ipython-input-143-11d785056a08>", line 8, in <module>
result = grid_search.fit(x, y)

File "/home/student/anaconda/lib/python3.5/site-packages/sklearn/grid_search.py", line 813, in fit
return self._fit(X, y, ParameterGrid(self.param_grid))

 File "/home/student/anaconda/lib/python3.5/site-packages/sklearn/grid_search.py", line 566, in _fit
n_folds = len(cv)

TypeError: object of type 'generator' has no len()

换线

gkf = GroupKFold( n_splits=3).split(x,y,group)

gkf = GroupKFold( n_splits=3)

也不起作用。 错误消息是:

'GroupKFold' object is not iterable

GroupKFoldsplit函数GroupKFold 生成一对训练和测试索引。 您应该在拆分值上调用list以将它们全部包含在列表中,以便可以计算长度:

gkf = list(GroupKFold( n_splits=3).split(x,y,group))

我认为这是使用分组K折叠简历生成器进行的sklearn网格搜索的重复。

我说GridSearchCV对象的fit方法为此采用了一个groups参数。 参见文档https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html#sklearn.model_selection.GridSearchCV.fit

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM