
How do I set SageMaker XGBoost's eval_metric to F1?

I tried SageMaker Autopilot on a binary classification problem and found that it uses F1 as the evaluation metric. But when I wrote my own training code without hyperparameter tuning, like this:

xgb.set_hyperparameters(max_depth=5,
                        eta=0.2,
                        gamma=4,
                        min_child_weight=6,
                        subsample=0.8,
                        objective='binary:logistic',
                        eval_metric='f1',
                        num_round=100)

This generates the following error:

[2021-10-17:00:02:19:ERROR] Customer Error: Metric 'f1' is not supported. Parameter 'eval_metric' should be one of these options:'rmse', 'mae', 'logloss', 'error', 'merror', 'mlogloss', 'auc', 'ndcg', 'map', 'poisson-nloglik', 'gamma-nloglik', 'gamma-deviance', 'tweedie-nloglik'.

Since Autopilot was able to compute F1, I assumed it must be supported as a hyperparameter in some form. Am I misunderstanding something?

Any help would be appreciated.

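The eval_metric hyperparameter only accepts the values listed in the error message, so 'f1' cannot be set there. What you can do instead is define the metrics you want to send to CloudWatch by passing a list of metric names and regular expressions as the metric_definitions argument when you initialize an Estimator object. The regular expressions are matched against the training job's log output, so your training code has to print the metric (for example F1) in a format the regex can capture. See the SageMaker documentation on defining training metrics for details.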
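To check a pattern before launching a job, the sketch below roughly emulates the idea locally with Python's re module: apply each Regex to a log line and record the captured value under its Name. This is an approximation for testing patterns, not SageMaker's actual implementation; extract_metrics is a hypothetical helper and the sample log line is made up.

```python
import re
from typing import Dict, List

# The same two metric definitions used in the Estimator example.
metric_definitions: List[Dict[str, str]] = [
    {"Name": "train:error", "Regex": "Train_error=(.*?);"},
    {"Name": "validation:error", "Regex": "Valid_error=(.*?);"},
]


def extract_metrics(log_line: str) -> Dict[str, float]:
    """Apply every Regex to one log line and keep the first capture group
    as a float, keyed by the metric Name (hypothetical local helper)."""
    found: Dict[str, float] = {}
    for definition in metric_definitions:
        match = re.search(definition["Regex"], log_line)
        if match:
            found[definition["Name"]] = float(match.group(1))
    return found


# Hypothetical line that a custom training script might print.
line = "[10] Train_error=0.125; Valid_error=0.2;"
print(extract_metrics(line))  # {'train:error': 0.125, 'validation:error': 0.2}
```

The non-greedy (.*?) stops at the first ";", which is why the training code should terminate each value with a semicolon when using these patterns.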

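import sagemaker
from sagemaker.estimator import Estimator

estimator = Estimator(
    image_uri="your-own-image-uri",  # placeholder: URI of your custom training image
    # get_execution_role() works inside SageMaker notebooks/Studio;
    # elsewhere, pass an IAM role ARN string instead.
    role=sagemaker.get_execution_role(),
    sagemaker_session=sagemaker.Session(),
    instance_count=1,
    instance_type='ml.c4.xlarge',
    # Each Regex is matched against the training job's logs; the captured
    # group is recorded in CloudWatch under Name. The training code must
    # therefore print lines such as "Train_error=0.12;".
    metric_definitions=[
        {'Name': 'train:error', 'Regex': 'Train_error=(.*?);'},
        {'Name': 'validation:error', 'Regex': 'Valid_error=(.*?);'}
    ]
)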
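The regex side only works if the training code itself prints the metric. Below is a minimal sketch, in plain Python, of what a custom training script might do after each evaluation: compute binary F1 from validation labels and predicted probabilities, then print it. The 0.5 threshold, the Valid_f1=...; log format, and the function names are all hypothetical; they would pair with a metric definition such as {'Name': 'validation:f1', 'Regex': 'Valid_f1=(.*?);'}. How you obtain the labels and probabilities depends on your own training code.

```python
from typing import Sequence


def f1_score(labels: Sequence[int], probs: Sequence[float], threshold: float = 0.5) -> float:
    """Binary F1 = 2*TP / (2*TP + FP + FN), equivalent to the harmonic
    mean of precision and recall. Probabilities >= threshold count as 1."""
    preds = [1 if p >= threshold else 0 for p in probs]
    tp = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 1)
    fp = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 0)
    denominator = 2 * tp + fp + fn
    # No positives and no predicted positives: report 0.0 rather than divide by zero.
    return 2 * tp / denominator if denominator else 0.0


def log_validation_f1(labels: Sequence[int], probs: Sequence[float]) -> str:
    # Hypothetical log format; it must match the Regex in metric_definitions.
    line = f"Valid_f1={f1_score(labels, probs):.4f};"
    print(line, flush=True)  # flush so the line reaches the job logs promptly
    return line


log_validation_f1([1, 0, 1, 1, 0], [0.9, 0.6, 0.4, 0.8, 0.1])  # prints Valid_f1=0.6667;
```

In the usage line, the predictions are [1, 1, 0, 1, 0], giving TP=2, FP=1, FN=1 and therefore F1 = 4/6.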


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM