How to match input/output with SageMaker batch transform?

I'm using SageMaker batch transform with JSON input files; see below for sample input/output files. I have custom inference code (below) that uses json.dumps to return the prediction, but the output does not come back as JSON. I tried to use "DataProcessing": {"JoinSource": "string"} to match input and output, but I get an error saying "unable to marshall...". I think this is because output_fn returns a list (or a list of lists) rather than JSON, so the input cannot be matched with the output. Any suggestions on how I should return the data?

Inference code:

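import json

# Assumption: the dangling "mimetype=accept)" in the original post suggests a
# worker.Response(...) wrapper from the SageMaker containers package.
from sagemaker_containers.beta.framework import worker

def model_fn(model_dir):
    ...
def input_fn(data, content_type):
    ...
def predict_fn(data, model):
    ...
def output_fn(prediction, accept):
    if accept == "application/json":
        return worker.Response(json.dumps(prediction), mimetype=accept)
    # Python has no RuntimeException; RuntimeError is the built-in equivalent.
    raise RuntimeError("{} accept type is not supported by this script.".format(accept))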

Input file:

{"data" : "input line  one" }
{"data" : "input line  two" }
....

Output file:

["output line  one" ]
["output line  two" ]

Transform job request:

{
   "BatchStrategy": "SingleRecord",
   "DataProcessing": { 
      "JoinSource": "string"
   },
   "MaxConcurrentTransforms": 3,
   "MaxPayloadInMB": 6,
   "ModelClientConfig": { 
      "InvocationsMaxRetries": 1,
      "InvocationsTimeoutInSeconds": 3600
   },
   "ModelName": "some-model",
   "TransformInput": { 
      "ContentType": "string",
      "DataSource": { 
         "S3DataSource": { 
            "S3DataType": "string",
            "S3Uri": "s3://bucket-sample"
         }
      },
      "SplitType": "Line"
   },
   "TransformJobName": "transform-job"
}
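
For comparison, here is a hedged sketch of what a complete request might look like as a Python dict. It is an illustration, not the asker's actual job: the output bucket, instance type, content types and S3DataType value are assumptions added here, and the request in the question omits TransformOutput and TransformResources, which the CreateTransformJob API requires. JoinSource is set to "Input", one of its two accepted values.

```python
# Sketch of CreateTransformJob arguments. Values marked "hypothetical" are
# placeholders introduced for illustration, not taken from the original post.
transform_job_request = {
    "TransformJobName": "transform-job",
    "ModelName": "some-model",
    "BatchStrategy": "SingleRecord",
    "MaxConcurrentTransforms": 3,
    "MaxPayloadInMB": 6,
    "ModelClientConfig": {
        "InvocationsMaxRetries": 1,
        "InvocationsTimeoutInSeconds": 3600,
    },
    "TransformInput": {
        "ContentType": "application/json",  # assumed content type
        "DataSource": {
            "S3DataSource": {
                "S3DataType": "S3Prefix",  # assumed
                "S3Uri": "s3://bucket-sample",
            }
        },
        "SplitType": "Line",
    },
    "TransformOutput": {
        "S3OutputPath": "s3://bucket-sample-output",  # hypothetical bucket
        "Accept": "application/json",
        "AssembleWith": "Line",
    },
    "TransformResources": {
        "InstanceType": "ml.m5.large",  # hypothetical instance type
        "InstanceCount": 1,
    },
    # JoinSource accepts "Input" or "None".
    "DataProcessing": {"JoinSource": "Input"},
}

# To submit the job (requires boto3 and AWS credentials):
# import boto3
# boto3.client("sagemaker").create_transform_job(**transform_job_request)
```

The dict is built separately from the API call so it can be inspected or validated locally before anything is sent to AWS.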

json.dumps will not convert your list into a dict structure before serializing it. A list is serialized as a JSON array (for example ["output line one"]), not as a JSON object.
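
The difference can be checked locally with a small snippet. This is only an illustration using the sample value from the question; the "output" key is an arbitrary name chosen here.

```python
import json

prediction = ["output line one"]  # a list, as in the sample output file

as_array = json.dumps(prediction)               # list -> JSON array
as_object = json.dumps({"output": prediction})  # dict -> JSON object

print(as_array)   # ["output line one"]
print(as_object)  # {"output": ["output line one"]}

# Parsing the strings back shows which top-level type each one carries.
print(type(json.loads(as_array)).__name__)   # list
print(type(json.loads(as_object)).__name__)  # dict
```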

What data type is prediction? Have you checked that prediction is a dict?

You can confirm the data type by adding print(type(prediction)) to your inference code and looking at the result in the CloudWatch Logs.

If prediction is a list, you can try the following:

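def output_fn(prediction, accept):
    if accept == "application/json":
        # Wrap the list in a dict so the serialized result is a JSON object.
        my_dict = {'output': prediction}
        # Assumption: worker is sagemaker_containers.beta.framework.worker;
        # the original post had lost the call that wraps json.dumps(...).
        return worker.Response(json.dumps(my_dict), mimetype=accept)

    # Python has no RuntimeException; RuntimeError is the built-in equivalent.
    raise RuntimeError("{} accept type is not supported by this script.".format(accept))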
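
The wrapping step can be tried outside the container before deploying. The helper below is a minimal sketch with a hypothetical name (to_json_object); it assumes that array-like predictions such as NumPy arrays expose a tolist() method, which may not match your model's actual return type.

```python
import json


def to_json_object(prediction, key="output"):
    """Serialize a prediction so the top-level JSON value is an object."""
    # Assumption: array-like predictions (e.g. NumPy arrays) offer .tolist().
    if hasattr(prediction, "tolist"):
        prediction = prediction.tolist()
    # Anything that is not already a dict is wrapped under a single key.
    if not isinstance(prediction, dict):
        prediction = {key: prediction}
    return json.dumps(prediction)


body = to_json_object(["output line one"])
print(body)
```

Inside output_fn, the returned string would take the place of json.dumps(my_dict).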

DataProcessing and JoinSource are used to associate the input data that is relevant to each prediction with the results in the output. They are not meant to be used to match the input and output formats. Note also that JoinSource accepts the values "Input" and "None"; "string" is only the type placeholder shown in the API reference.
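
To get a feel for what the join produces, here is a local simulation. It is not SageMaker code; it follows my reading of the documented behavior for JSON records with JoinSource set to "Input", where the result is attached to the input object under a SageMakerOutput key. The function name simulate_join is hypothetical.

```python
import json


def simulate_join(input_line, output_line):
    """Local illustration of joining one input record with its prediction."""
    record = json.loads(input_line)
    result = json.loads(output_line)
    if isinstance(record, dict):
        # Input is a key-value object: attach the prediction to a copy of it.
        joined = dict(record)
        joined["SageMakerOutput"] = result
    else:
        # Otherwise keep both sides under separate keys.
        joined = {"SageMakerInput": record, "SageMakerOutput": result}
    return json.dumps(joined)


joined_line = simulate_join(
    '{"data": "input line one"}',
    '{"output": ["output line one"]}',
)
print(joined_line)
```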
