繁体   English   中英

如何在 AWS sagemaker >= 2.0 中更新现有的 model

[英]How to update an existing model in AWS sagemaker >= 2.0

我有一个 XGBoost model 目前正在使用 AWS sagemaker 进行生产并进行实时推理。 过了一会儿,我想用一个经过更多数据训练的更新的 model 更新并保持一切不变(例如,相同的端点,相同的推理过程,所以除了 model 本身之外真的没有任何变化)

当前的部署过程如下:

from sagemaker.xgboost.model import XGBoostModel
from sagemaker.xgboost.model import XGBoostPredictor

xgboost_model = XGBoostModel(
    model_data = <S3 url>,
    role = <sagemaker role>,
    entry_point = 'inference.py',
    source_dir = 'src',
    code_location = <S3 url of other dependencies>
    framework_version='1.5-1',
    name = model_name)

xgboost_model.deploy(
    instance_type='ml.c5.large',
    initial_instance_count=1,
    endpoint_name = model_name)

现在我在几周后更新了 model,我想重新部署它。 我知道.deploy()方法创建了一个端点和一个端点配置,因此它可以完成所有工作。 我不能简单地重新运行我的脚本,因为我会遇到错误。

在以前版本的 sagemaker 中,我可以更新 model 并使用一个额外的参数传递给.deploy()方法,称为update_endpoint = True 在 sagemaker >=2.0 中,这是一个无操作。 现在,在 sagemaker >= 2.0 中,我需要使用文档中所述的预测器 object。 所以我尝试以下方法:

predictor = XGBoostPredictor(model_name)
predictor.update_endpoint(model_name= model_name)

它实际上根据新的端点配置更新端点。 但是,我不知道它在更新什么......我没有在上面的 2 行代码中指定我们需要考虑使用更多数据训练的新xgboost_model ......所以我在哪里告诉更新需要更多最近的 model?

谢谢!

更新

我相信我需要查看此处文档中所述的生产变体。 然而,他们的整个教程是基于亚马逊 sdk 的 python (boto3) 的,当我对每个 model 变体都有不同的入口点时,这些工件很难管理。py 脚本(例如不同的inference.py )。

在您的 model_name 中,您指定 SageMaker Model object 的名称,您可以在其中指定 image_uri、model_data 等。

由于我找到了自己问题的答案,因此我将在此处发布给遇到相同问题的人。

我最终使用 boto3 SDK 而不是 sagemaker SDK(或一些文档建议的两者的混合)重新编码了我的所有部署脚本。

这是显示如何创建 sagemaker model object、端点配置和端点以首次部署 model 的整个脚本。 此外,它还展示了如何使用更新的 model 更新端点(这是我的主要问题)

如果您想带上自己的 model 并使用 sagemaker 在生产中安全地更新它,下面是执行所有 3 项的代码:

import boto3
import time
from datetime import datetime
from sagemaker import image_uris
from fileManager import *  # this is a local script for helper functions

# name of zipped model and zipped inference code
CODE_TAR = 'your_inference_code_and_other_artifacts.tar.gz'
MODEL_TAR = 'your_saved_xgboost_model.tar.gz'

# sagemaker params
smClient = boto3.client('sagemaker')
smRole = <your_sagemaker_role>
bucket = sagemaker.Session().default_bucket()

# deploy algorithm
class Deployer:

    def __init__(self, modelName, deployRetrained=False):
        self.modelName=modelName
        self.deployRetrained = deployRetrained
        self.prefix = <S3_model_path_prefix>
    
    def deploy(self):
        '''
        Main method to create a sagemaker model, create an endpoint configuration and deploy the model. If deployRetrained
        param is set to True, this method will update an already existing endpoint.
        '''
        # define model name and endpoint name to be used for model deployment/update
        model_name = self.modelName + <any_suffix>
        endpoint_config_name = self.modelName + '-%s' %datetime.now().strftime('%Y-%m-%d-%HH%M')
        endpoint_name = self.modelName
        
        # deploy model for the first time
        if not self.deployRetrained:
            print('Deploying for the first time')

            # here you should copy and zip the model dependencies that you may have (such as preprocessors, inference code, config code...)
            # mine were zipped into the file called CODE_TAR

            # upload model and model artifacts needed for inference to S3
            uploadFile(list_files=[MODEL_TAR, CODE_TAR], prefix = self.prefix)

            # create sagemaker model and endpoint configuration
            self.createSagemakerModel(model_name)
            self.createEndpointConfig(endpoint_config_name, model_name)

            # deploy model and wait while endpoint is being created
            self.createEndpoint(endpoint_name, endpoint_config_name)
            self.waitWhileCreating(endpoint_name)
        
        # update model
        else:
            print('Updating existing model')

            # upload model and model artifacts needed for inference (here the old ones are replaced)
            # make sure to make a backup in S3 if you would like to keep the older models
            # we replace the old ones and keep the same names to avoid having to recreate a sagemaker model with a different name for the update!
            uploadFile(list_files=[MODEL_TAR, CODE_TAR], prefix = self.prefix)

            # create a new endpoint config that takes the new model
            self.createEndpointConfig(endpoint_config_name, model_name)

            # update endpoint
            self.updateEndpoint(endpoint_name, endpoint_config_name)

            # wait while endpoint updates then delete outdated endpoint config once it is InService
            self.waitWhileCreating(endpoint_name)
            self.deleteOutdatedEndpointConfig(model_name, endpoint_config_name)

    def createSagemakerModel(self, model_name):
        ''' 
        Create a new sagemaker Model object with an xgboost container and an entry point for inference using boto3 API
        '''
        # Retrieve that inference image (container)
        docker_container = image_uris.retrieve(region=region, framework='xgboost', version='1.5-1')

        # Relative S3 path to pre-trained model to create S3 model URI
        model_s3_key = f'{self.prefix}/'+ MODEL_TAR

        # Combine bucket name, model file name, and relate S3 path to create S3 model URI
        model_url = f's3://{bucket}/{model_s3_key}'

        # S3 path to the necessary inference code
        code_url = f's3://{bucket}/{self.prefix}/{CODE_TAR}'
        
        # Create a sagemaker Model object with all its artifacts
        smClient.create_model(
            ModelName = model_name,
            ExecutionRoleArn = smRole,
            PrimaryContainer = {
                'Image': docker_container,
                'ModelDataUrl': model_url,
                'Environment': {
                    'SAGEMAKER_PROGRAM': 'inference.py', #inference.py is at the root of my zipped CODE_TAR
                    'SAGEMAKER_SUBMIT_DIRECTORY': code_url,
                }
            }
        )
    
    def createEndpointConfig(self, endpoint_config_name, model_name):
        ''' 
        Create an endpoint configuration (only for boto3 sdk procedure) and set production variants parameters.
        Each retraining procedure will induce a new variant name based on the endpoint configuration name.
        '''
        smClient.create_endpoint_config(
            EndpointConfigName=endpoint_config_name,
            ProductionVariants=[
                {
                    'VariantName': endpoint_config_name,
                    'ModelName': model_name,
                    'InstanceType': INSTANCE_TYPE,
                    'InitialInstanceCount': 1
                }
            ]
        )

    def createEndpoint(self, endpoint_name, endpoint_config_name):
        '''
        Deploy the model to an endpoint
        '''
        smClient.create_endpoint(
            EndpointName=endpoint_name,
            EndpointConfigName=endpoint_config_name)
    
    def deleteOutdatedEndpointConfig(self, name_check, current_endpoint_config):
        '''
        Automatically detect and delete endpoint configurations that contain a string 'name_check'. This method can be used
        after a retrain procedure to delete all previous endpoint configurations but keep the current one named 'current_endpoint_config'.
        '''
        # get a list of all available endpoint configurations
        all_configs = smClient.list_endpoint_configs()['EndpointConfigs']

        # loop over the names of endpoint configs
        names_list = []
        for config_dict in all_configs:
            endpoint_config_name = config_dict['EndpointConfigName']

            # get only endpoint configs that contain name_check in them and save names to a list
            if name_check in endpoint_config_name:
                names_list.append(endpoint_config_name)
        
        # remove the current endpoint configuration from the list (we do not want to detele this one since it is live)
        names_list.remove(current_endpoint_config)

        for name in names_list:
            try:
                smClient.delete_endpoint_config(EndpointConfigName=name)
                print('Deleted endpoint configuration for %s' %name)
            except:
                print('INFO : No endpoint configuration was found for %s' %endpoint_config_name)

    def updateEndpoint(self, endpoint_name, endpoint_config_name):
        ''' 
        Update existing endpoint with a new retrained model
        '''
        smClient.update_endpoint(
            EndpointName=endpoint_name,
            EndpointConfigName=endpoint_config_name,
            RetainAllVariantProperties=True)
    
    def waitWhileCreating(self, endpoint_name):
        ''' 
        While the endpoint is being created or updated sleep for 60 seconds.
        '''
        # wait while creating or updating endpoint
        status = smClient.describe_endpoint(EndpointName=endpoint_name)['EndpointStatus']
        print('Status: %s' %status)
        while status != 'InService' and status !='Failed':
            time.sleep(60)
            status = smClient.describe_endpoint(EndpointName=endpoint_name)['EndpointStatus']
            print('Status: %s' %status)
        
        # in case of a deployment failure raise an error
        if status == 'Failed':
            raise ValueError('Endpoint failed to deploy')

if __name__=="__main__":
    deployer = Deployer('churnmodel', deployRetrained=True)
    deployer.deploy()

最后评论:

  • sagemaker 文档提到了所有这些,但没有提到 state,您可以为create_model方法提供“entry_point”以及用于推理依赖项(例如规范化工件)的“source_dir”。 可以按照PrimaryContainer参数中的说明来完成。

  • 我的fileManager.py脚本只包含制作 tar 文件、上传和下载到我的 S3 路径的基本功能。 为了简化 class,我没有将它们包括在内。

  • 方法 deleteOutdatedEndpointConfig 可能看起来像一个不必要的循环和检查的过度杀伤,我这样做是因为我有多个端点配置要处理并且想要删除那些不存在的并且包含字符串name_check (我不知道确切的名称的配置,因为有一个日期时间后缀)。 如果您愿意,请随意简化它或将其全部删除。

希望能帮助到你。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM