簡體   English   中英

scikit-learn 中的多標簽分類與超參數搜索:指定平均

[英]Multilabel classification in scikit-learn with hyperparameter search: specifying averaging

我正在研究一個簡單的多輸出分類問題,並注意到每當運行以下代碼時都會出現此錯誤:

ValueError: Target is multilabel-indicator but average='binary'. Please 
choose another average setting, one of [None, 'micro', 'macro', 'weighted', 'samples'].

我理解它所引用的問題,即在評估多標簽模型時,需要明確設置平均類型。 盡管如此,我無法弄清楚這個average參數應該 go 到哪里,因為只有accuracy_scoreprecision_scorerecall_score內置方法有這個參數,我沒有在我的代碼中明確使用。 此外,由於我正在執行RandomizedSearch ,因此我不能只將precision_score(average='micro')傳遞給scoringrefit調整 arguments ,因為precision_score()需要傳遞正確且真實的y標簽。 這就是為什么這個以前的 SO question這個 here都有類似的問題,沒有幫助。

我的示例數據生成代碼如下:

from sklearn.datasets import make_multilabel_classification
from sklearn.naive_bayes import MultinomialNB
from sklearn.multioutput import MultiOutputClassifier
from sklearn.model_selection import RandomizedSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler

X, Y = make_multilabel_classification(
    n_samples=1000,
    n_features=2,
    n_classes=5,
    n_labels=2
)

pipe = Pipeline(
    steps = [
        ('scaler', MinMaxScaler()),
        ('model', MultiOutputClassifier(MultinomialNB()))
    ]
)

search = RandomizedSearchCV(
    estimator = pipe,
    param_distributions={'model__estimator__alpha': (0.01,1)},
    scoring = ['accuracy', 'precision', 'recall'],
    refit = 'precision',
    cv = 5
).fit(X, Y)

我錯過了什么?

從 scikit-learn 文檔中,我看到您可以傳遞一個可調用的函數,該函數返回一個字典,其中鍵是指標名稱,值是指標分數。 這意味着您可以編寫自己的評分 function,它必須將估計器X_testy_test作為輸入。 這反過來必須計算 y_pred 並使用它來計算您想要使用的分數。 這你可以做內置的方法。 在那里,您可以指定應該使用哪個關鍵字 arguments 來計算分數。 在看起來像的代碼中

def my_scorer(estimator, X_test, y_test) -> dict[str, float]:
    y_pred = estimator.predict(X_test)
    return {
        'accuracy': accuracy_score(y_test, y_pred),
        'precision': precision_score(y_test, y_pred, average='micro'),
        'recall': recall_score(y_test, y_pred, average='micro'),
    }

search = RandomizedSearchCV(
    estimator = pipe,
    param_distributions={'model__estimator__alpha': (0.01,1)},
    scoring = my_scorer,
    refit = 'precision',
    cv = 5
).fit(X, Y)

評分指標表中,請注意f1_microf1_macro等,以及針對precisionrecall給出的注釋“后綴適用於 'f1'”。 所以例如

search = RandomizedSearchCV(
    ...
    scoring = ['accuracy', 'precision_micro', 'recall_macro'],
    ...
)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM