Performing K-fold cross-validation with scoring = 'f1', 'recall', or 'precision' for a multi-class problem

I know this is easy to do for a binary classification problem, but it seems harder for a multi-class problem.

I have an imbalanced dataset for a 4-class classification problem. I applied RandomForestClassifier() to it and evaluated measures such as accuracy, precision, recall, and f1_score. Now I want to run 10-fold cross-validation on the training set, with the 'scoring' parameter of cross_val_score() set to 'f1' instead of 'accuracy'.

My code:

# Random Forest
import numpy as np
from sklearn.ensemble import RandomForestClassifier
np.random.seed(123)
classifier_RF = RandomForestClassifier(random_state = 0)
classifier_RF.fit(X_train, Y_train)

# Applying k-Fold Cross Validation
from sklearn.model_selection import cross_val_score
accuracies = cross_val_score(estimator = classifier_RF, X = X_train, y = Y_train, cv = 10, scoring = 'f1')
print("F1_Score: {:.2f} %".format(accuracies.mean()*100))
print("Standard Deviation: {:.2f} %".format(accuracies.std()*100))

However, when I try to run this code, I am getting an error as follows:

ValueError: Target is multiclass but average='binary'. Please choose another average setting, one of [None, 'micro', 'macro', 'weighted'].

I have tried passing average = 'weighted' to cross_val_score() as follows:

accuracies = cross_val_score(estimator = classifier_RF, X = X_train, y = Y_train, cv = 10, scoring = 'f1', average = 'weighted')

but that's giving an error as follows:

TypeError: cross_val_score() got an unexpected keyword argument 'average'

The full traceback for the first call (scoring = 'f1', without the average argument) is as follows:

Traceback (most recent call last):

  File "<ipython-input-1-ba4a5e1de09a>", line 97, in <module>
    accuracies = cross_val_score(estimator = classifier_RF, X = X_train, y = Y_train, cv = 10, scoring = 'f1')

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/sklearn/utils/validation.py", line 72, in inner_f
    return f(**kwargs)

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/sklearn/model_selection/_validation.py", line 406, in cross_val_score
    error_score=error_score)

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/sklearn/utils/validation.py", line 72, in inner_f
    return f(**kwargs)

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/sklearn/model_selection/_validation.py", line 248, in cross_validate
    for train, test in cv.split(X, y, groups))

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py", line 1048, in __call__
    if self.dispatch_one_batch(iterator):

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py", line 866, in dispatch_one_batch
    self._dispatch(tasks)

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py", line 784, in _dispatch
    job = self._backend.apply_async(batch, callback=cb)

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/joblib/_parallel_backends.py", line 208, in apply_async
    result = ImmediateResult(func)

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/joblib/_parallel_backends.py", line 572, in __init__
    self.results = batch()

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py", line 263, in __call__
    for func, args, kwargs in self.items]

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py", line 263, in <listcomp>
    for func, args, kwargs in self.items]

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/sklearn/model_selection/_validation.py", line 560, in _fit_and_score
    test_scores = _score(estimator, X_test, y_test, scorer)

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/sklearn/model_selection/_validation.py", line 607, in _score
    scores = scorer(estimator, X_test, y_test)

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/sklearn/metrics/_scorer.py", line 88, in __call__
    *args, **kwargs)

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/sklearn/metrics/_scorer.py", line 213, in _score
    **self._kwargs)

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/sklearn/utils/validation.py", line 72, in inner_f
    return f(**kwargs)

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/sklearn/metrics/_classification.py", line 1047, in f1_score
    zero_division=zero_division)

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/sklearn/utils/validation.py", line 72, in inner_f
    return f(**kwargs)

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/sklearn/metrics/_classification.py", line 1175, in fbeta_score
    zero_division=zero_division)

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/sklearn/utils/validation.py", line 72, in inner_f
    return f(**kwargs)

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/sklearn/metrics/_classification.py", line 1434, in precision_recall_fscore_support
    pos_label)

  File "/Users/vivekchowdary/opt/anaconda3/lib/python3.7/site-packages/sklearn/metrics/_classification.py", line 1265, in _check_set_wise_labels
    % (y_type, average_options))

ValueError: Target is multiclass but average='binary'. Please choose another average setting, one of [None, 'micro', 'macro', 'weighted'].

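You need to use make_scorer to define your metric and its parameters. The average argument belongs to f1_score, not to cross_val_score, so it has to be bound to the scorer itself. Note that cross_val_score expects a single scorer (a string or a callable), not a dict of scorers, so pass the scorer object directly:

from sklearn.metrics import make_scorer, f1_score

# Bind average='weighted' to f1_score so it works with multi-class targets
scoring = make_scorer(f1_score, average='weighted')

and then use this in your cross_val_score:

from sklearn.model_selection import cross_val_score

results = cross_val_score(estimator = classifier_RF, 
                          X = X_train, 
                          y = Y_train, 
                          cv = 10, 
                          scoring = scoring)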
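
For reference, here is a minimal self-contained sketch of the same idea. It is not the asker's data: X_demo and y_demo are a synthetic, imbalanced 4-class dataset from make_classification, and the names clf, cv, scores_named, scores_custom and multi are made up for the illustration. It assumes a reasonably recent scikit-learn, where the built-in scorer name 'f1_weighted' should give the same numbers as make_scorer(f1_score, average='weighted'). Since the question also mentions precision and recall, the last call uses cross_validate, which accepts several scorer names at once.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score, make_scorer
from sklearn.model_selection import StratifiedKFold, cross_val_score, cross_validate

# Hypothetical stand-in for X_train / Y_train: imbalanced 4-class data.
X_demo, y_demo = make_classification(
    n_samples=600,
    n_features=12,
    n_informative=6,
    n_classes=4,
    weights=[0.5, 0.25, 0.15, 0.10],
    random_state=0,
)

clf = RandomForestClassifier(n_estimators=50, random_state=0)

# Fixed splits so that every call below sees the same 10 folds.
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)

# Option 1: built-in scorer name with the averaging baked in.
scores_named = cross_val_score(clf, X_demo, y_demo, cv=cv, scoring="f1_weighted")

# Option 2: explicit scorer built with make_scorer.
weighted_f1 = make_scorer(f1_score, average="weighted")
scores_custom = cross_val_score(clf, X_demo, y_demo, cv=cv, scoring=weighted_f1)

# Option 3: several metrics in one run via cross_validate.
multi = cross_validate(
    clf,
    X_demo,
    y_demo,
    cv=cv,
    scoring=["f1_weighted", "precision_weighted", "recall_weighted"],
)

print("F1 (weighted): {:.2f} %".format(scores_named.mean() * 100))
print("Standard deviation: {:.2f} %".format(scores_named.std() * 100))
print("Precision (weighted): {:.2f} %".format(multi["test_precision_weighted"].mean() * 100))
print("Recall (weighted): {:.2f} %".format(multi["test_recall_weighted"].mean() * 100))
```

Which average to pick depends on what you want to measure: 'weighted' weights each class by its support, while 'macro' (scorer name 'f1_macro') treats all four classes equally, which is often more informative on imbalanced data.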
