

How to run inference on a Keras model hosted on AWS SageMaker via an AWS Lambda function?

I have a pre-trained Keras model that I have hosted on AWS using AWS SageMaker. I've got an endpoint, and I can make successful predictions from an Amazon SageMaker notebook instance.

In the notebook, I serve a .PNG image as shown below, and the model gives me the correct prediction:

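import boto3
import numpy as np
import sagemaker
from sagemaker.tensorflow.model import TensorFlowPredictor
from skimage.io import imread
from skimage.transform import resize

# bucketname and filename_1 are your S3 bucket name and object key
s3 = boto3.resource('s3')
file_name_1 = 'normal.png'
# download_file() saves the object locally and returns None
s3.Bucket(bucketname).download_file(filename_1, file_name_1)

sagemaker_session = sagemaker.Session()
endpoint = 'tensorflow-inference-0000-11-22-33-44-55-666'  # endpoint name

predictor = TensorFlowPredictor(endpoint, sagemaker_session)
# the outer list adds a batch dimension: shape (1, 137, 310, 3)
data = np.array([resize(imread(file_name_1), (137, 310, 3))])
predictor.predict(data)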
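In the notebook, predictor.predict(data) works with a raw NumPy array because the SageMaker Python SDK serializes the input for you: as far as I know, in SDK v2 TensorFlowPredictor uses a JSON serializer by default, which converts arrays to nested lists. The low-level boto3 invoke_endpoint call used from Lambda sends its Body unchanged, so that step has to be done by hand. A minimal sketch of the equivalent serialization, using a hypothetical helper to_json_body (an approximation, not the SDK's own code):

```python
import json

import numpy as np


def to_json_body(array):
    """Hypothetical helper: serialize a NumPy array to a JSON string.

    Approximates what the SDK's JSON serializer does before sending a
    request; invoke_endpoint() needs bytes or a string, not an ndarray.
    """
    return json.dumps(np.asarray(array).tolist())


# A dummy batch with the same shape as the notebook input: (1, 137, 310, 3)
batch = np.zeros((1, 137, 310, 3), dtype=np.float64)
body = to_json_body(batch)

# Deserializing gives back nested lists with the original shape
restored = np.array(json.loads(body))
print(restored.shape)  # (1, 137, 310, 3)
```

The same json.dumps(data.tolist()) idea is what makes the request body acceptable to invoke_endpoint with ContentType='application/json'.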

Now I want to make predictions from a mobile application. For that, I wrote a Lambda function in Python and attached an API Gateway to it. My Lambda function is the following:

import os
import sys

CWD = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, os.path.join(CWD, "lib"))

import json
import base64
import boto3
import numpy as np
from scipy import signal
from scipy.signal import butter, lfilter
from scipy.io import wavfile
import scipy.signal as sps
import io
from io import BytesIO
from matplotlib import pyplot as plt
import matplotlib.image as mpimg
from datetime import datetime
from skimage.io import imread
from skimage.transform import resize
from PIL import Image

ENDPOINT_NAME = 'tensorflow-inference-0000-11-22-33-44-55-666'
runtime= boto3.client('runtime.sagemaker')

def lambda_handler(event, context):
    s3 = boto3.client("s3")
    
    # retrieving data from event.
    get_file_content_from_postman = event["content"]
    
    # decoding data.
    decoded_file_name = base64.b64decode(get_file_content_from_postman)
    
    image = Image.open(io.BytesIO(decoded_file_name))

    data = np.array([resize(imread(image), (137, 310, 3))])
    
    response = runtime.invoke_endpoint(EndpointName=ENDPOINT_NAME, ContentType='text/csv', Body=data)
        
    result = json.loads(response['Body'].read().decode())
    
    return result

The third-to-last line gives me the error 'PngImageFile' object has no attribute 'read'. Any idea what I am missing here?

If io.BytesIO(decoded_file_name) correctly represents your image data (though the name decoded_file_name suggests that it is only a file name, not the actual image data), then you don't need PIL at all. skimage's imread expects a file name or a file-like object, not a PIL Image, and passing the PngImageFile object to it is what produces the 'read' error. Just pass the BytesIO object directly:

data = np.array([resize(imread(io.BytesIO(decoded_file_name)), (137, 310, 3))])
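To make the distinction in this answer concrete: base64.b64decode returns the raw bytes of the PNG file, not a file name, and wrapping those bytes in io.BytesIO gives imread a file-like object with a read() method. The stdlib-only sketch below, built around a hypothetical helper decode_png_payload, performs that decoding step and checks the fixed 8-byte PNG signature before the bytes are handed to an image reader:

```python
import base64
import io

# Every PNG file starts with this fixed 8-byte signature
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def decode_png_payload(encoded):
    """Hypothetical helper: turn the base64 string from event["content"]
    into an in-memory, file-like object that image readers can consume."""
    raw = base64.b64decode(encoded)  # raw image bytes, not a file name
    if not raw.startswith(PNG_SIGNATURE):
        raise ValueError("payload is not a PNG image")
    return io.BytesIO(raw)  # has .read(), unlike a PIL Image object


# Usage with a fake payload that only contains the PNG signature
fake = base64.b64encode(PNG_SIGNATURE + b"rest-of-file").decode("ascii")
buf = decode_png_payload(fake)
print(buf.read(4))  # b'\x89PNG'
```

In the Lambda function, the returned buffer could be passed to imread in place of io.BytesIO(decoded_file_name); the signature check is an optional extra that fails early on non-PNG uploads.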

I was missing one thing that was causing this error. After decoding the image data, I converted the NumPy array to a Python list (of lists) with tolist(), serialized it with json.dumps, and sent it with ContentType='application/json'. Below is the code for reference.

import os
import sys

# make packages bundled in the deployment package's "lib" folder importable
CWD = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, os.path.join(CWD, "lib"))

import base64
import io
import json

import boto3
import numpy as np
from skimage.io import imread
from skimage.transform import resize

# grab the endpoint name from the Lambda function's environment variables
ENDPOINT_NAME = os.environ['ENDPOINT_NAME']
runtime = boto3.client('runtime.sagemaker')

def lambda_handler(event, context):
    # retrieve the base64-encoded image sent by the client (e.g. Postman)
    get_file_content_from_postman = event["content"]

    # decode it: this yields the raw PNG bytes, not a file name
    decoded_file_name = base64.b64decode(get_file_content_from_postman)

    # imread accepts a file-like object; the outer list adds a batch dimension
    data = np.array([resize(imread(io.BytesIO(decoded_file_name)), (137, 310, 3))])

    # Body must be bytes or a string, not an ndarray, so send the array as JSON nested lists
    payload = json.dumps(data.tolist())

    response = runtime.invoke_endpoint(EndpointName=ENDPOINT_NAME,
                                       ContentType='application/json',
                                       Body=payload)

    result = json.loads(response['Body'].read().decode())

    return result
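The handler reads the image from event["content"], which assumes API Gateway passes the request's JSON body straight through as the event (a non-proxy integration, or a mapping template that does so). With Lambda proxy integration, the JSON body instead arrives as a string in event["body"], and API Gateway expects a response with statusCode and a string body. A minimal sketch of both sides, using hypothetical helpers build_request_body (what the mobile app or Postman would send), extract_content (tolerant of either event shape) and to_proxy_response:

```python
import base64
import json


def build_request_body(image_bytes):
    """Hypothetical client-side helper: the JSON body the mobile app
    (or Postman) would POST to the API Gateway endpoint."""
    return json.dumps({"content": base64.b64encode(image_bytes).decode("ascii")})


def extract_content(event):
    """Hypothetical handler-side helper: return the base64 image string
    whether the event is the parsed body itself (non-proxy integration)
    or a Lambda proxy event whose JSON body is a string in event["body"]."""
    if "content" in event:
        return event["content"]
    return json.loads(event.get("body") or "{}")["content"]


def to_proxy_response(result):
    """Hypothetical helper: with Lambda proxy integration, API Gateway
    expects a status code and a string body in the returned dict."""
    return {
        "statusCode": 200,
        "headers": {"Content-Type": "application/json"},
        "body": json.dumps(result),
    }


# Usage: both event shapes yield the same base64 string
png_bytes = b"\x89PNG\r\n\x1a\n" + b"fake image data"
body = build_request_body(png_bytes)
non_proxy_event = json.loads(body)
proxy_event = {"body": body}
print(base64.b64decode(extract_content(proxy_event)) == png_bytes)  # True
```

Under proxy integration, lambda_handler would call extract_content(event) instead of event["content"] and return to_proxy_response(result) instead of result.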
        

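Statement: The technical posts on this site are licensed under CC BY-SA 4.0. If you repost them, please cite this site's URL or the original address. For any questions, contact: yoyou2525@163.com.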

 
Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM