
GridSearchCV for multi-label classification for each label separately

I am doing multi-label classification using scikit-learn. I am using RandomForestClassifier as the base estimator. I want to optimize its parameters for each label using GridSearchCV. Currently I am doing it in the following way:

from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.model_selection import GridSearchCV  # replaces the old sklearn.grid_search module

parameters = {
    "estimator__n_estimators": [5, 50, 200],
    "estimator__max_depth": [None, 10, 20],
    "estimator__min_samples_split": [2, 5, 10],
}
# class_weight='balanced' replaces the deprecated 'auto' option
model_to_tune = OneVsRestClassifier(RandomForestClassifier(random_state=0, class_weight='balanced'))
# Plain 'f1' only scores binary targets, so an averaged variant is used for multi-label y
model_tuned = GridSearchCV(model_to_tune, param_grid=parameters, scoring='f1_weighted', n_jobs=2)
model_tuned.fit(X, y)  # X: feature matrix, y: label indicator matrix
print(model_tuned.best_params_)
{'estimator__min_samples_split': 10, 'estimator__max_depth': None, 'estimator__n_estimators': 200}

These are the parameters that give the best F1 score considering all the labels. I want to find the parameters separately for each label. Is there any built-in function that can do that?

It's not hard to do that, though it is not built-in, and I'm not sure I understand why you would want to.

Simply pre-process your data like so:

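params = {"n_estimators": [5, 50, 200], "max_depth": [None, 10, 20], "min_samples_split": [2, 5, 10]}  # no "estimator__" prefix: the forest is tuned directly
best_params = {}
for a_class in list_of_unique_classes:
    y_this_class = (y_all_class == a_class)  # binary target: True where the sample belongs to this class
    model_to_tune = RandomForestClassifier(random_state=0, class_weight='balanced')
    model_tuned = GridSearchCV(model_to_tune, param_grid=params, scoring='f1', n_jobs=2)
    model_tuned.fit(X, y_this_class)

    # Save the best parameters for this class
    best_params[a_class] = model_tuned.best_params_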
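The loop above builds each binary target by comparing a single label vector against a class value. When the labels are instead stored as a binary indicator matrix (one column per label, which is what OneVsRestClassifier takes for multi-label input), the same idea can be applied column by column. The following is only a sketch under stated assumptions: the data comes from make_multilabel_classification, the grid is deliberately tiny so it runs quickly, and the names Y, param_grid and best_params_per_label are illustrative, not from the original post.

```python
from sklearn.datasets import make_multilabel_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Synthetic multi-label data (an assumption for illustration):
# Y has shape (n_samples, n_labels) with 0/1 entries.
X, Y = make_multilabel_classification(
    n_samples=200, n_features=10, n_classes=3, random_state=0
)

# Deliberately small grid so the sketch runs quickly.
param_grid = {"n_estimators": [5, 20], "max_depth": [None, 5]}

best_params_per_label = {}
for label_idx in range(Y.shape[1]):
    y_this_label = Y[:, label_idx]  # binary target for one label
    search = GridSearchCV(
        RandomForestClassifier(random_state=0, class_weight="balanced"),
        param_grid=param_grid,
        scoring="f1",  # binary F1 is valid here because each target is binary
        cv=3,
    )
    search.fit(X, y_this_label)
    best_params_per_label[label_idx] = search.best_params_

# One parameter dictionary per label
for label_idx, best in best_params_per_label.items():
    print(label_idx, best)
```

Each entry of best_params_per_label can then be used to build a separate RandomForestClassifier for that label.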

(Also, beware of the F1 score: it does not do a good job of describing the performance of a classifier on skewed data sets. You want to use ROC curves and/or informedness.)
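As a rough illustration of the metrics mentioned above, the sketch below computes F1, ROC AUC and informedness (true positive rate plus true negative rate minus 1) for a single binary label. Everything here is an assumption made for the example: the skewed data set comes from make_classification with a 90/10 class split, and the model settings are arbitrary.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split

# Skewed synthetic binary problem (assumption): roughly 10% positives.
X, y = make_classification(
    n_samples=1000, n_features=10, weights=[0.9, 0.1], random_state=0
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=0
)

clf = RandomForestClassifier(
    n_estimators=50, random_state=0, class_weight="balanced"
)
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)               # hard 0/1 predictions
y_score = clf.predict_proba(X_test)[:, 1]  # probability of the positive class

# Informedness = TPR + TNR - 1, computed from the confusion matrix.
tn, fp, fn, tp = confusion_matrix(y_test, y_pred, labels=[0, 1]).ravel()
informedness = tp / (tp + fn) + tn / (tn + fp) - 1

f1 = f1_score(y_test, y_pred)
auc = roc_auc_score(y_test, y_score)

print("F1:", f1)
print("ROC AUC:", auc)
print("Informedness:", informedness)
```

To tune on ROC AUC rather than F1, scoring='roc_auc' can be passed to GridSearchCV for a binary target.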


 