
How to use a TensorFlow Serving saved model with a string tensor input to predict on the local machine?

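I'm trying to run a saved serving model on my local machine. However, it takes a string tensor as input, and I can't convert my image into the correct string format.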

To load the model, I use:

import tensorflow as tf

saved_model = tf.saved_model.load('model/1/')
inf_model = saved_model.signatures['serving_default']

The model has the following input/output signature:

inputs {
  key: "encoded"
  value {
    name: "serving_default_encoded:0"
    dtype: DT_STRING
    tensor_shape {
    }
  }
}
outputs {
  key: "output_0"
  value {
    name: "StatefulPartitionedCall:0"
    dtype: DT_FLOAT
    tensor_shape {
      dim {
        size: 19451
      }
    }
  }
}

method_name: "tensorflow/serving/predict"

To preprocess the image, I use this:

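    img = tf.io.read_file(path)
    # Decodes the image to an H x W x 3 tensor of type uint8
    img = tf.io.decode_image(img, channels=3)
    # Convert uint8 -> float32 in [0, 1] before resizing: resize_with_pad already
    # returns float32, so converting afterwards would not rescale the values
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.image.resize_with_pad(img, 224, 224)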

I tried converting it to the string tensor format like this:

    img_encoded = base64.urlsafe_b64encode(img).decode("utf-8")
    enc = tf.constant(img_encoded)

Then I run the prediction:

    pred = inf_model(encoded=enc)['sequential_1'][0]

However, I get the following error:

Traceback (most recent call last):

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/spyder_kernels/py3compat.py", line 356, in compat_exec
    exec(code, globals, locals)

  File "/home/james/Desktop/Project/dev_test/inference.py", line 79, in <module>
    res = inf_model(encoded=enc)['sequential_1'][0]

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1669, in __call__
    return self._call_impl(args, kwargs)

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1678, in _call_impl
    return self._call_with_structured_signature(args, kwargs,

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1759, in _call_with_structured_signature
    return self._call_flat(

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py", line 115, in _call_flat
    return super(_WrapperFunction, self)._call_flat(args, captured_inputs,

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1918, in _call_flat
    return self._build_call_outputs(self._inference_function.call(

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 555, in call
    outputs = execute.execute(

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,

InvalidArgumentError:  Unknown image file format. One of JPEG, PNG, GIF, BMP required.
     [[{{node StatefulPartitionedCall/decode_image/DecodeImage}}]] [Op:__inference_signature_wrapper_129216]

Function call stack:
signature_wrapper

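The error is caused by the image format: the data is not in JPEG, PNG, GIF or BMP format. An image may have a .jpg extension but actually be a TIFF file. Since the error says the image file format is unknown, check the actual type of each image and use the following code to find the images in the dataset that are not JPEG, PNG, GIF or BMP, so they can be removed: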

import os
import cv2
import imghdr

def check_images(s_dir, ext_list):
    """Return the files in the class sub-directories of s_dir whose real format
    (read from the file header) is not in ext_list, or that OpenCV cannot read."""
    bad_images = []
    bad_ext = []
    for klass in os.listdir(s_dir):
        klass_path = os.path.join(s_dir, klass)
        print('processing class directory', klass)
        if not os.path.isdir(klass_path):
            print('*** WARNING *** you have files in', s_dir, '- it should only contain sub directories')
            continue
        for f in os.listdir(klass_path):
            f_path = os.path.join(klass_path, f)
            if not os.path.isfile(f_path):
                print('*** fatal error, you have a sub directory', f, 'in class directory', klass)
                continue
            # Detect the real format from the file header, not from the extension
            tip = imghdr.what(f_path)
            if tip not in ext_list:
                bad_images.append(f_path)
                bad_ext.append(tip)
                continue
            # cv2.imread returns None (it does not raise) when it cannot decode a file
            img = cv2.imread(f_path)
            if img is None:
                print('file', f_path, 'is not a valid image file')
                bad_images.append(f_path)
    return bad_images, bad_ext

source_dir = r'c:\temp\people\storage'
good_exts = ['jpg', 'png', 'jpeg', 'gif', 'bmp']  # acceptable formats as reported by imghdr
bad_file_list, bad_ext_list = check_images(source_dir, good_exts)
if bad_file_list:
    print('improper image files are listed below')
    for bad_file in bad_file_list:
        print(bad_file)
else:
    print('no improper image files were found')
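The check above relies on `imghdr`, which was deprecated in Python 3.11 and removed in Python 3.13, so it will not run on newer interpreters. A minimal sketch of a replacement, assuming it is enough to compare a file's leading "magic" bytes against the signatures of the four formats named in the error message (JPEG, PNG, GIF, BMP). The names `sniff_image_format`, `sniff_image_file` and `MAGIC_PREFIXES` are hypothetical, and the `BM` test for BMP is as loose as the one `imghdr` used:

```python
# Hypothetical stand-in for imghdr.what(): recognise only the formats listed in
# the DecodeImage error (JPEG, PNG, GIF, BMP) by their leading signature bytes.
MAGIC_PREFIXES = {
    b"\xff\xd8\xff": "jpeg",          # JPEG start-of-image marker
    b"\x89PNG\r\n\x1a\n": "png",      # 8-byte PNG signature
    b"GIF87a": "gif",
    b"GIF89a": "gif",
    b"BM": "bmp",                     # loose check, same as imghdr's
}


def sniff_image_format(data: bytes):
    """Return 'jpeg', 'png', 'gif' or 'bmp', or None if the header is not recognised."""
    for prefix, name in MAGIC_PREFIXES.items():
        if data.startswith(prefix):
            return name
    return None


def sniff_image_file(path: str):
    """Read only the first bytes of a file and identify its format."""
    with open(path, "rb") as fh:
        return sniff_image_format(fh.read(16))
```

Inside `check_images`, `tip = sniff_image_file(f_path)` could replace the `imghdr.what` call; a TIFF file renamed to `.jpg` then yields `None` and is reported as bad.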

Removing such images from the dataset should help.
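The traceback also shows that the failing node, `StatefulPartitionedCall/decode_image/DecodeImage`, runs inside the serving signature. That suggests the `encoded` input is meant to be the raw bytes of an image file, which the signature decodes itself, rather than base64 text built from already-decoded float pixels; such text does not start with any JPEG/PNG/GIF/BMP header, which would explain "Unknown image file format". A minimal sketch under that assumption: the helper names are hypothetical, the `encoded` and `output_0` keys come from the signature in the question (which lists no `sequential_1` output), and TensorFlow is imported only inside the inference helper:

```python
# Leading bytes of the formats named in the DecodeImage error (JPEG, PNG, GIF, BMP).
IMAGE_HEADERS = (b"\xff\xd8\xff", b"\x89PNG\r\n\x1a\n", b"GIF87a", b"GIF89a", b"BM")


def read_encoded_image(path: str) -> bytes:
    """Return the raw bytes of an image file, unchanged (no decoding, no base64)."""
    with open(path, "rb") as fh:
        data = fh.read()
    # Fail early with a clear message instead of inside the model's DecodeImage op.
    if not data.startswith(IMAGE_HEADERS):
        raise ValueError(f"{path} does not look like a JPEG, PNG, GIF or BMP file")
    return data


def predict_scores(inf_model, path: str):
    """Hypothetical helper: run the 'serving_default' signature on one image file."""
    import tensorflow as tf  # imported here so the file-reading part works without TF

    # The signature declares a scalar DT_STRING input named "encoded".
    encoded = tf.constant(read_encoded_image(path))
    # Its output "output_0" has shape [19451]: one value per class, no batch axis.
    return inf_model(encoded=encoded)["output_0"]
```

With `inf_model` loaded as in the question, `scores = predict_scores(inf_model, path)` followed by `tf.argmax(scores)` would give the index of the highest-scoring class. To confirm what the signature expects, `saved_model_cli show --dir model/1/ --all` prints every signature in the SavedModel.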


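Disclaimer: the technical posts on this site are licensed under CC BY-SA 4.0. If you repost them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.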

 
Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM