Creating and deploying pre-trained tensorflow model with pre-processing and post-processing in AWS SageMaker
I am trying to deploy the pre-trained MaskRCNN model of https://github.com/matterport/Mask_RCNN in SageMaker for prediction.
The problem is that the model uses numpy and scikit-learn for pre-processing before feeding inputs to the Keras layers, so just deploying the model as in this example would not work.
Things I have tried:
- 431 Request Header Fields Too Large is returned, since the pre-processed data (formed by the resized images plus anchors) is way larger than the original data.
- Creating an entry_point.py script with the input_handler() function. This script looks like the following:

    import argparse
    import json
    import os
    import subprocess
    import sys

    import numpy as np


    def install(package: str):
        """ pip install a package """
        subprocess.check_call([sys.executable, "-q", "-m", "pip", "install", package])


    def parse_args():
        parser = argparse.ArgumentParser()
        parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR'))
        return parser.parse_known_args()


    def input_handler(data, context):
        """ Expects @data to be json. Pre-process the data to be ready to be fed to the Keras layers.
        https://www.mikulskibartosz.name/custom-preprocessing-in-tensorflow-with-sagemaker-endpoints/ """
        data_dec = np.array(json.loads(data['inputs']))
        # preprocess_images() is not shown here
        [molded_images, image_metas, anchors] = preprocess_images(data_dec)
        # input_image, input_image_meta and input_anchors are the names of the Input layers of the model
        inputs = {
            "inputs": {
                "input_image": molded_images.tolist(),
                "input_image_meta": image_metas.tolist(),
                "input_anchors": anchors.tolist()
            }
        }
        return json.dumps(inputs)


    if __name__ == "__main__":
        args, _ = parse_args()
        install('scikit-image')
        install('scipy')

And then creating the model and the endpoint in a SageMaker Notebook Instance:

    import sagemaker
    from sagemaker.tensorflow.model import TensorFlowModel

    model = TensorFlowModel(model_data=url_to_saved_model_s3,
                            role=sagemaker.get_execution_role(),
                            framework_version='2.3.1',
                            entry_point='entry_point.py')

    predictor = model.deploy(initial_instance_count=1, instance_type='ml.t2.medium')  # create endpoint

But it does not look like input_handler() is ever called.
Any help on how to deploy a model for inference that needs to pre-process non-tensor inputs?
You can work with numpy and sklearn in your script. You merely need to create a requirements.txt listing these packages in your code directory, alongside your entry point script.
As for the input_handler and output_handler functionality, you want to build it in the following format, as shown in the SageMaker Tensorflow Serving Container.

    import json


    def input_handler(data, context):
        """ Pre-process request input before it is sent to TensorFlow Serving REST API
        Args:
            data (obj): the request data, in format of dict or string
            context (Context): an object containing request and configuration details
        Returns:
            (dict): a JSON-serializable dict that contains request body and headers
        """
        if context.request_content_type == 'application/json':
            # pass through json (assumes it's correctly formed)
            d = data.read().decode('utf-8')
            return d if len(d) else ''

        if context.request_content_type == 'text/csv':
            # very simple csv handler
            return json.dumps({
                'instances': [float(x) for x in data.read().decode('utf-8').split(',')]
            })

        raise ValueError('{{"error": "unsupported content type {}"}}'.format(
            context.request_content_type or "unknown"))


    def output_handler(data, context):
        """Post-process TensorFlow Serving output before it is returned to the client.
        Args:
            data (obj): the TensorFlow serving response
            context (Context): an object containing request and configuration details
        Returns:
            (bytes, string): data to return to client, response content type
        """
        if data.status_code != 200:
            raise ValueError(data.content.decode('utf-8'))

        response_content_type = context.accept_header
        prediction = data.content
        return prediction, response_content_type

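
Adapted to the question, a handler in this format might look roughly like the sketch below. It is only an illustration under a few assumptions: `preprocess_images` here is a hypothetical stub standing in for the real Mask R-CNN pre-processing, the request body is assumed to be JSON with an `inputs` key, and the fake request objects at the bottom exist only so the handler can be exercised locally without an endpoint.

```python
import json
from io import BytesIO
from types import SimpleNamespace


def preprocess_images(images):
    """Hypothetical stand-in for the question's Mask R-CNN pre-processing.

    The real function would return molded images, image metas and anchors;
    this stub returns three nested lists so the sketch runs on its own.
    """
    molded_images = images
    image_metas = [[len(image)] for image in images]
    anchors = [[[0.0, 0.0, 1.0, 1.0]] for _ in images]
    return molded_images, image_metas, anchors


def input_handler(data, context):
    """Turn a JSON request into the body sent to TensorFlow Serving."""
    if context.request_content_type != "application/json":
        raise ValueError("unsupported content type {}".format(
            context.request_content_type or "unknown"))

    # `data` is a stream, so read and decode it before parsing
    payload = json.loads(data.read().decode("utf-8"))
    molded_images, image_metas, anchors = preprocess_images(payload["inputs"])

    # Keys follow the Input layer names given in the question
    return json.dumps({
        "inputs": {
            "input_image": molded_images,
            "input_image_meta": image_metas,
            "input_anchors": anchors,
        }
    })


# Local smoke test with fake request objects (no endpoint involved)
fake_data = BytesIO(json.dumps({"inputs": [[1, 2, 3]]}).encode("utf-8"))
fake_context = SimpleNamespace(request_content_type="application/json")
body = json.loads(input_handler(fake_data, fake_context))
print(sorted(body["inputs"]))
```

Calling the handler locally like this is a cheap way to catch parsing mistakes, such as indexing into the stream instead of reading it, before deploying.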
Make sure to log each step of your functions. The logs are emitted to CloudWatch, so you can see where your script is breaking and we can get an idea of the error you are dealing with.
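
One possible way to add such logging without cluttering the handlers is a small decorator, sketched below. The `log_calls` helper and the `inference` logger name are hypothetical, and the sketch assumes that anything written to stdout by the container ends up in the endpoint's CloudWatch log stream.

```python
import functools
import logging
import sys

# Assumption: the container forwards stdout to CloudWatch
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger("inference")
logger.setLevel(logging.INFO)


def log_calls(func):
    """Log entry, exit and any exception of a handler (hypothetical helper)."""
    @functools.wraps(func)
    def wrapper(data, context):
        logger.info("%s called", func.__name__)
        try:
            result = func(data, context)
        except Exception:
            # logger.exception also records the full traceback
            logger.exception("%s failed", func.__name__)
            raise
        logger.info("%s finished", func.__name__)
        return result
    return wrapper


@log_calls
def input_handler(data, context):
    # Placeholder body; the real pre-processing would go here
    return data
```

If the "called" message never shows up in the logs, the handler is not being invoked at all, which is a different problem from an exception raised inside it.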