Invocation timed out using SageMaker to invoke endpoints with pretrained custom PyTorch model [Inference]

I have a pretrained PyTorch model (contextualized_topic_models) and have deployed it using AWS SageMaker script mode. However, when I try to invoke the endpoint for inference, it always returns an "Invocation timed out" error no matter what I try. I have tried different types of input and changing the input_fn() function, but it still doesn't work.

I've tried running my inference.py script on Colab (without connecting to the AWS server), and each function seems to work perfectly fine, returning the expected predictions.

I've been trying to debug this for four days now; I've even thought about this issue in my dreams... I'll be deeply grateful for any help.

Here's my deployment script.

from sagemaker.pytorch.model import PyTorchModel

pytorch_model = PyTorchModel(
    model_data=pretrained_model_data,
    entry_point="inference.py",
    role=role,
    framework_version="1.8.1",
    py_version="py36",
    sagemaker_session=sess,
)

endpoint_name = "topic-modeling-inference"

# Deploy
predictor = pytorch_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g4dn.xlarge",
    endpoint_name=endpoint_name,
)
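
Once deploy() returns, it can be worth confirming that the endpoint actually reached InService before invoking it. A minimal sanity check with boto3 (deploy() normally waits for this itself, so this is just a double check):

import boto3

sm_client = boto3.client("sagemaker")
desc = sm_client.describe_endpoint(EndpointName="topic-modeling-inference")
print(desc["EndpointStatus"])  # expect "InService" before invoking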

Endpoint test (prediction) script

# Test the model
import json
import boto3

sm = boto3.client('sagemaker-runtime')
endpoint_name = "topic-modeling-inference"

prompt = [
    "Here is a piece of cake."
]

promptbody = [x.encode('utf-8') for x in prompt]
promptbody = promptbody[0]
#body= bytes(prompt[0], 'utf-8')
#tryout = prompt[0]


response = sm.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="text/csv",
    Body=promptbody
    #Body=tryout.encode(encoding='UTF-8')
)

print(response)

#result = json.loads(response['Body'].read().decode('utf-8'))
#print(result)
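
One detail worth noting: the request above is sent with ContentType="text/csv", while the input_fn shown below only JSON-decodes "application/json" bodies, so a CSV request reaches predict_fn as raw bytes. A minimal sketch of an invocation that exercises the JSON branch instead (the list-of-strings payload shape is an assumption based on how input_fn and predict_fn are written):

import json
import boto3

sm = boto3.client("sagemaker-runtime")
response = sm.invoke_endpoint(
    EndpointName="topic-modeling-inference",
    ContentType="application/json",
    # Send the documents as a JSON list so input_fn's json.loads() branch runs
    Body=json.dumps(["Here is a piece of cake."]),
)
print(json.loads(response["Body"].read().decode("utf-8")))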

Part of my inference.py script

import json
import numpy as np

def predict_fn(input_data, model):
    # tp10 is the preprocessing pipeline created elsewhere in inference.py
    input_data_features = tp10.transform(text_for_contextual=input_data)
    topic_prediction = model.get_doc_topic_distribution(input_data_features, n_samples=20)
    # Return the index of the most probable topic as a plain Python int
    topicID = int(np.argmax(topic_prediction))
    return topicID
    #prediction = model.get_topic_lists(20)[np.argmax(topic_prediction)]
    #return prediction

def input_fn(request_body, request_content_type):
    if request_content_type == "application/json":
        request = json.loads(request_body)
    else:
        # Any other content type (e.g. text/csv) is passed through untouched
        request = request_body
    return request

def output_fn(prediction, response_content_type):
    # Both branches currently produce the same JSON string
    if response_content_type == "application/json":
        response = json.dumps(prediction)
    else:
        response = json.dumps(prediction)
    return response
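
For reference, a minimal local smoke test of the kind described above for Colab, assuming the three handlers are pasted into the same session; StubPreprocessor and StubModel are hypothetical stand-ins for tp10 and the topic model, so only the handler plumbing is exercised:

import json
import numpy as np

class StubPreprocessor:
    def transform(self, text_for_contextual):
        return text_for_contextual  # the real tp10 would featurize the text

class StubModel:
    def get_doc_topic_distribution(self, features, n_samples=20):
        return np.random.rand(len(features), 10)  # fake 10-topic distribution

tp10 = StubPreprocessor()  # stands in for the module-level tp10 used by predict_fn

request = input_fn(json.dumps(["Here is a piece of cake."]), "application/json")
print(output_fn(predict_fn(request, StubModel()), "application/json"))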

Any help or guidance will be wonderful. Thank you in advance.

I would suggest looking into the CloudWatch logs of the endpoint to see if any invocations are reaching the endpoint.

If so, check in the same log file whether the endpoint is sending a response back without any errors.
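
A minimal sketch of pulling those logs with boto3 (SageMaker endpoints write to the log group /aws/sagemaker/Endpoints/<endpoint-name>; the stream names vary by variant and instance, so list them first):

import boto3

logs = boto3.client("logs")
group = "/aws/sagemaker/Endpoints/topic-modeling-inference"

# Most recently active streams first
streams = logs.describe_log_streams(
    logGroupName=group, orderBy="LastEventTime", descending=True
)
for stream in streams["logStreams"][:3]:
    events = logs.get_log_events(
        logGroupName=group, logStreamName=stream["logStreamName"], limit=50
    )
    for event in events["events"]:
        print(event["message"])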
