[英]Multilabel classification in scikit-learn with hyperparameter search: specifying averaging
我正在研究一個簡單的多輸出分類問題,並注意到每當運行以下代碼時都會出現此錯誤:
ValueError: Target is multilabel-indicator but average='binary'. Please
choose another average setting, one of [None, 'micro', 'macro', 'weighted', 'samples'].
我理解它所引用的問題,即在評估多標簽模型時,需要明確設置平均類型。 盡管如此,我無法弄清楚這個average
參數應該 go 到哪里,因為只有accuracy_score
、 precision_score
、 recall_score
內置方法有這個參數,我沒有在我的代碼中明確使用。 此外,由於我正在執行RandomizedSearch
,因此我不能只將precision_score(average='micro')
傳遞給scoring
或refit
調整 arguments ,因為precision_score()
需要傳遞正確且真實的y
標簽。 這就是為什么這個以前的 SO question和這個 here都有類似的問題,沒有幫助。
我的示例數據生成代碼如下:
from sklearn.datasets import make_multilabel_classification
from sklearn.naive_bayes import MultinomialNB
from sklearn.multioutput import MultiOutputClassifier
from sklearn.model_selection import RandomizedSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler
X, Y = make_multilabel_classification(
n_samples=1000,
n_features=2,
n_classes=5,
n_labels=2
)
pipe = Pipeline(
steps = [
('scaler', MinMaxScaler()),
('model', MultiOutputClassifier(MultinomialNB()))
]
)
search = RandomizedSearchCV(
estimator = pipe,
param_distributions={'model__estimator__alpha': (0.01,1)},
scoring = ['accuracy', 'precision', 'recall'],
refit = 'precision',
cv = 5
).fit(X, Y)
我錯過了什么?
從 scikit-learn 文檔中,我看到您可以傳遞一個可調用的函數,該函數返回一個字典,其中鍵是指標名稱,值是指標分數。 這意味着您可以編寫自己的評分 function,它必須將估計器X_test
和y_test
作為輸入。 這反過來必須計算 y_pred 並使用它來計算您想要使用的分數。 這你可以做內置的方法。 在那里,您可以指定應該使用哪個關鍵字 arguments 來計算分數。 在看起來像的代碼中
def my_scorer(estimator, X_test, y_test) -> dict[str, float]:
y_pred = estimator.predict(X_test)
return {
'accuracy': accuracy_score(y_test, y_pred),
'precision': precision_score(y_test, y_pred, average='micro'),
'recall': recall_score(y_test, y_pred, average='micro'),
}
search = RandomizedSearchCV(
estimator = pipe,
param_distributions={'model__estimator__alpha': (0.01,1)},
scoring = my_scorer,
refit = 'precision',
cv = 5
).fit(X, Y)
從評分指標表中,請注意f1_micro
、 f1_macro
等,以及針對precision
和recall
給出的注釋“后綴適用於 'f1'”。 所以例如
search = RandomizedSearchCV(
...
scoring = ['accuracy', 'precision_micro', 'recall_macro'],
...
)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.