
Loading a trained model into a SageMaker Estimator

I've trained a custom model on SageMaker using the PyTorch estimator.
Training has completed, and I verified that the model artifacts were saved to an S3 location.

I want to load my trained model into my SageMaker notebook so I can run analysis, inference, and so on.

I did the following, but I'm not sure this is the right method, because it asks for an instance type. My understanding is that when loading an already-trained estimator, I should only need to declare which compute instance type to use once I start deploying the model for inference.

```python
from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    model_data=ModelArtifact_S3_LOCATION,
    entry_point='train.py',
    source_dir='code',
    role=role,
    framework_version='1.5.0',
    py_version='py3',
)
```

If training has completed and you want to set up for inference, you can either point to your tar.gz model artifact in S3 to create an endpoint, or deploy your training estimator directly. The following code block shows the general flow for training, deployment, and prediction.

```python
import sagemaker
from sagemaker.pytorch import PyTorch

role = sagemaker.get_execution_role()  # IAM role attached to the notebook

# Train my estimator
pytorch_estimator = PyTorch(entry_point='train_and_deploy.py',
                            role=role,
                            instance_type='ml.p3.2xlarge',
                            instance_count=1,
                            framework_version='1.8.0',
                            py_version='py3')
pytorch_estimator.fit('s3://my_bucket/my_training_data/')

# Deploy my estimator to a SageMaker Endpoint and get a Predictor
predictor = pytorch_estimator.deploy(instance_type='ml.m4.xlarge',
                                     initial_instance_count=1)

# `data` is a NumPy array or a Python list.
# `response` is a NumPy array.
response = predictor.predict(data)
```
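The snippet in the question constructs an `Estimator`, which is the training-side object. For an artifact that already exists in S3, the SageMaker Python SDK also provides `sagemaker.pytorch.PyTorchModel`, where the instance type is only supplied at `deploy()` time. What follows is a minimal sketch, assuming the question's `train.py` in `code/` also defines the `model_fn` that the PyTorch serving container calls to load the weights. The helper name `deploy_from_artifact` and its default instance type are illustrative, not part of the SDK.

```python
def deploy_from_artifact(model_data: str, role: str, instance_type: str = "ml.m4.xlarge"):
    """Create an endpoint from an existing model.tar.gz without retraining.

    Hypothetical helper: model_data is the S3 URI of the artifact written by training.
    """
    # Imported here so the sketch can be defined even where the SDK is not installed.
    from sagemaker.pytorch import PyTorchModel

    model = PyTorchModel(
        model_data=model_data,
        role=role,
        entry_point="train.py",     # assumed to define model_fn for inference
        source_dir="code",
        framework_version="1.5.0",  # match the version used for training
        py_version="py3",
    )
    # The compute instance type is chosen only here, at deployment time.
    return model.deploy(initial_instance_count=1, instance_type=instance_type)
```

Usage would be along the lines of `predictor = deploy_from_artifact(ModelArtifact_S3_LOCATION, role)` followed by `predictor.predict(data)`. If the original training job is still on record, `PyTorch.attach(training_job_name)` is another route: it rebuilds the estimator from the finished job, and `.deploy(...)` can then be called on it just as in the training example.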

For more information, see the SageMaker Python SDK documentation on deploying PyTorch models: https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#deploy-pytorch-models
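For analysis inside the notebook without hosting an endpoint, the artifact can instead be downloaded and unpacked locally. The sketch below uses the standard library plus `boto3` for the download. The helper names are hypothetical, and the file names inside the archive depend on what your training script saved to `/opt/ml/model`.

```python
from __future__ import annotations

import tarfile
from pathlib import Path
from urllib.parse import urlparse


def parse_s3_uri(uri: str) -> tuple[str, str]:
    """Split 's3://bucket/key/parts' into ('bucket', 'key/parts')."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an S3 URI: {uri!r}")
    return parsed.netloc, parsed.path.lstrip("/")


def download_artifact(uri: str, local_path: str) -> str:
    """Download model.tar.gz from S3 (requires boto3 and AWS credentials)."""
    import boto3  # imported lazily so the other helpers work without boto3

    bucket, key = parse_s3_uri(uri)
    boto3.client("s3").download_file(bucket, key, local_path)
    return local_path


def extract_artifact(tar_path: str, dest_dir: str) -> list[str]:
    """Unpack the archive into dest_dir and return the sorted file names it held.

    Only extract archives you trust, such as your own training output.
    """
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with tarfile.open(tar_path, "r:gz") as tar:
        tar.extractall(dest)
        return sorted(m.name for m in tar.getmembers() if m.isfile())
```

In a notebook this might look like `download_artifact(ModelArtifact_S3_LOCATION, "model.tar.gz")`, then `extract_artifact("model.tar.gz", "model")`, and finally `torch.load(...)` on the extracted weights file (for a saved `state_dict`, loaded into an instance of your own model class).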

I work for AWS & my opinions are my own

