[英]How to update an existing model in AWS sagemaker >= 2.0
我有一個 XGBoost model 目前正在使用 AWS sagemaker 進行生產並進行實時推理。 過了一會兒,我想用一個經過更多數據訓練的更新的 model 更新並保持一切不變(例如,相同的端點,相同的推理過程,所以除了 model 本身之外真的沒有任何變化)
當前的部署過程如下:
from sagemaker.xgboost.model import XGBoostModel
from sagemaker.xgboost.model import XGBoostPredictor
xgboost_model = XGBoostModel(
model_data = <S3 url>,
role = <sagemaker role>,
entry_point = 'inference.py',
source_dir = 'src',
code_location = <S3 url of other dependencies>
framework_version='1.5-1',
name = model_name)
xgboost_model.deploy(
instance_type='ml.c5.large',
initial_instance_count=1,
endpoint_name = model_name)
現在我在幾周后更新了 model,我想重新部署它。 我知道.deploy()
方法創建了一個端點和一個端點配置,因此它可以完成所有工作。 我不能簡單地重新運行我的腳本,因為我會遇到錯誤。
在以前版本的 sagemaker 中,我可以更新 model 並使用一個額外的參數傳遞給.deploy()
方法,稱為update_endpoint = True
。 在 sagemaker >=2.0 中,這是一個無操作。 現在,在 sagemaker >= 2.0 中,我需要使用文檔中所述的預測器 object。 所以我嘗試以下方法:
predictor = XGBoostPredictor(model_name)
predictor.update_endpoint(model_name= model_name)
它實際上根據新的端點配置更新端點。 但是,我不知道它在更新什么......我沒有在上面的 2 行代碼中指定我們需要考慮使用更多數據訓練的新xgboost_model
......所以我在哪里告訴更新需要更多最近的 model?
謝謝!
更新
我相信我需要查看此處文檔中所述的生產變體。 然而,他們的整個教程是基於亞馬遜 sdk 的 python (boto3) 的,當我對每個 model 變體都有不同的入口點時,這些工件很難管理。py 腳本(例如不同的inference.py
)。
在您的 model_name 中,您指定 SageMaker Model object 的名稱,您可以在其中指定 image_uri、model_data 等。
由於我找到了自己問題的答案,因此我將在此處發布給遇到相同問題的人。
我最終使用 boto3 SDK 而不是 sagemaker SDK(或一些文檔建議的兩者的混合)重新編碼了我的所有部署腳本。
這是顯示如何創建 sagemaker model object、端點配置和端點以首次部署 model 的整個腳本。 此外,它還展示了如何使用更新的 model 更新端點(這是我的主要問題)
如果您想帶上自己的 model 並使用 sagemaker 在生產中安全地更新它,下面是執行所有 3 項的代碼:
import boto3
import time
from datetime import datetime
from sagemaker import image_uris
from fileManager import * # this is a local script for helper functions
# name of zipped model and zipped inference code
CODE_TAR = 'your_inference_code_and_other_artifacts.tar.gz'
MODEL_TAR = 'your_saved_xgboost_model.tar.gz'
# sagemaker params
smClient = boto3.client('sagemaker')
smRole = <your_sagemaker_role>
bucket = sagemaker.Session().default_bucket()
# deploy algorithm
class Deployer:
def __init__(self, modelName, deployRetrained=False):
self.modelName=modelName
self.deployRetrained = deployRetrained
self.prefix = <S3_model_path_prefix>
def deploy(self):
'''
Main method to create a sagemaker model, create an endpoint configuration and deploy the model. If deployRetrained
param is set to True, this method will update an already existing endpoint.
'''
# define model name and endpoint name to be used for model deployment/update
model_name = self.modelName + <any_suffix>
endpoint_config_name = self.modelName + '-%s' %datetime.now().strftime('%Y-%m-%d-%HH%M')
endpoint_name = self.modelName
# deploy model for the first time
if not self.deployRetrained:
print('Deploying for the first time')
# here you should copy and zip the model dependencies that you may have (such as preprocessors, inference code, config code...)
# mine were zipped into the file called CODE_TAR
# upload model and model artifacts needed for inference to S3
uploadFile(list_files=[MODEL_TAR, CODE_TAR], prefix = self.prefix)
# create sagemaker model and endpoint configuration
self.createSagemakerModel(model_name)
self.createEndpointConfig(endpoint_config_name, model_name)
# deploy model and wait while endpoint is being created
self.createEndpoint(endpoint_name, endpoint_config_name)
self.waitWhileCreating(endpoint_name)
# update model
else:
print('Updating existing model')
# upload model and model artifacts needed for inference (here the old ones are replaced)
# make sure to make a backup in S3 if you would like to keep the older models
# we replace the old ones and keep the same names to avoid having to recreate a sagemaker model with a different name for the update!
uploadFile(list_files=[MODEL_TAR, CODE_TAR], prefix = self.prefix)
# create a new endpoint config that takes the new model
self.createEndpointConfig(endpoint_config_name, model_name)
# update endpoint
self.updateEndpoint(endpoint_name, endpoint_config_name)
# wait while endpoint updates then delete outdated endpoint config once it is InService
self.waitWhileCreating(endpoint_name)
self.deleteOutdatedEndpointConfig(model_name, endpoint_config_name)
def createSagemakerModel(self, model_name):
'''
Create a new sagemaker Model object with an xgboost container and an entry point for inference using boto3 API
'''
# Retrieve that inference image (container)
docker_container = image_uris.retrieve(region=region, framework='xgboost', version='1.5-1')
# Relative S3 path to pre-trained model to create S3 model URI
model_s3_key = f'{self.prefix}/'+ MODEL_TAR
# Combine bucket name, model file name, and relate S3 path to create S3 model URI
model_url = f's3://{bucket}/{model_s3_key}'
# S3 path to the necessary inference code
code_url = f's3://{bucket}/{self.prefix}/{CODE_TAR}'
# Create a sagemaker Model object with all its artifacts
smClient.create_model(
ModelName = model_name,
ExecutionRoleArn = smRole,
PrimaryContainer = {
'Image': docker_container,
'ModelDataUrl': model_url,
'Environment': {
'SAGEMAKER_PROGRAM': 'inference.py', #inference.py is at the root of my zipped CODE_TAR
'SAGEMAKER_SUBMIT_DIRECTORY': code_url,
}
}
)
def createEndpointConfig(self, endpoint_config_name, model_name):
'''
Create an endpoint configuration (only for boto3 sdk procedure) and set production variants parameters.
Each retraining procedure will induce a new variant name based on the endpoint configuration name.
'''
smClient.create_endpoint_config(
EndpointConfigName=endpoint_config_name,
ProductionVariants=[
{
'VariantName': endpoint_config_name,
'ModelName': model_name,
'InstanceType': INSTANCE_TYPE,
'InitialInstanceCount': 1
}
]
)
def createEndpoint(self, endpoint_name, endpoint_config_name):
'''
Deploy the model to an endpoint
'''
smClient.create_endpoint(
EndpointName=endpoint_name,
EndpointConfigName=endpoint_config_name)
def deleteOutdatedEndpointConfig(self, name_check, current_endpoint_config):
'''
Automatically detect and delete endpoint configurations that contain a string 'name_check'. This method can be used
after a retrain procedure to delete all previous endpoint configurations but keep the current one named 'current_endpoint_config'.
'''
# get a list of all available endpoint configurations
all_configs = smClient.list_endpoint_configs()['EndpointConfigs']
# loop over the names of endpoint configs
names_list = []
for config_dict in all_configs:
endpoint_config_name = config_dict['EndpointConfigName']
# get only endpoint configs that contain name_check in them and save names to a list
if name_check in endpoint_config_name:
names_list.append(endpoint_config_name)
# remove the current endpoint configuration from the list (we do not want to detele this one since it is live)
names_list.remove(current_endpoint_config)
for name in names_list:
try:
smClient.delete_endpoint_config(EndpointConfigName=name)
print('Deleted endpoint configuration for %s' %name)
except:
print('INFO : No endpoint configuration was found for %s' %endpoint_config_name)
def updateEndpoint(self, endpoint_name, endpoint_config_name):
'''
Update existing endpoint with a new retrained model
'''
smClient.update_endpoint(
EndpointName=endpoint_name,
EndpointConfigName=endpoint_config_name,
RetainAllVariantProperties=True)
def waitWhileCreating(self, endpoint_name):
'''
While the endpoint is being created or updated sleep for 60 seconds.
'''
# wait while creating or updating endpoint
status = smClient.describe_endpoint(EndpointName=endpoint_name)['EndpointStatus']
print('Status: %s' %status)
while status != 'InService' and status !='Failed':
time.sleep(60)
status = smClient.describe_endpoint(EndpointName=endpoint_name)['EndpointStatus']
print('Status: %s' %status)
# in case of a deployment failure raise an error
if status == 'Failed':
raise ValueError('Endpoint failed to deploy')
if __name__=="__main__":
deployer = Deployer('churnmodel', deployRetrained=True)
deployer.deploy()
最后評論:
sagemaker 文檔提到了所有這些,但沒有提到 state,您可以為create_model
方法提供“entry_point”以及用於推理依賴項(例如規范化工件)的“source_dir”。 可以按照PrimaryContainer
參數中的說明來完成。
我的fileManager.py
腳本只包含制作 tar 文件、上傳和下載到我的 S3 路徑的基本功能。 為了簡化 class,我沒有將它們包括在內。
方法 deleteOutdatedEndpointConfig 可能看起來像一個不必要的循環和檢查的過度殺傷,我這樣做是因為我有多個端點配置要處理並且想要刪除那些不存在的並且包含字符串name_check
(我不知道確切的名稱的配置,因為有一個日期時間后綴)。 如果您願意,請隨意簡化它或將其全部刪除。
希望能幫助到你。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.