
How to use a TensorFlow Serving saved model with a string tensor input to predict on the local machine?

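I'm trying to run a saved serving model on my local machine. However, the model takes a string tensor as input, and I can't convert my image into the correct string format.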

To load the model I use:

import tensorflow as tf

saved_model = tf.saved_model.load('model/1/')
inf_model = saved_model.signatures['serving_default']

The model has the following input/output structure:

inputs {
  key: "encoded"
  value {
    name: "serving_default_encoded:0"
    dtype: DT_STRING
    tensor_shape {
    }
  }
}
outputs {
  key: "output_0"
  value {
    name: "StatefulPartitionedCall:0"
    dtype: DT_FLOAT
    tensor_shape {
      dim {
        size: 19451
      }
    }
  }
}

method_name: "tensorflow/serving/predict"

To preprocess the image, I use this:

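    img = tf.io.read_file(path)
    # Decodes the image to an H x W x 3 tensor of type uint8
    img = tf.io.decode_image(img, channels=3)
    # Scale to float32 in [0, 1] before resizing: resize_with_pad already returns
    # float32, and convert_image_dtype does not rescale a float32 input
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.image.resize_with_pad(img, 224, 224)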

I tried to convert it to the string tensor format like this:

    import base64

    img_encoded = base64.urlsafe_b64encode(img).decode("utf-8")
    img_encoded = tf.constant(img_encoded)
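A side note on what `encoded` has to contain: it is a scalar `DT_STRING`, and the `decode_image/DecodeImage` node in the traceback below suggests that the signature decodes that string itself. If so, it probably expects the bytes of an encoded image file (what `tf.io.read_file(path)` returns), not base64 text of already-decoded float pixels. The stdlib sketch below illustrates the difference by checking the leading signature bytes of the four formats named in the error. The helper `has_supported_image_header` and the stand-in byte buffers are hypothetical, for illustration only.

```python
import base64
import struct

# Leading "magic" bytes of the formats named in the error message.
SIGNATURES = {
    b"\xff\xd8\xff": "jpeg",
    b"\x89PNG\r\n\x1a\n": "png",
    b"GIF87a": "gif",
    b"GIF89a": "gif",
    b"BM": "bmp",
}

def has_supported_image_header(data: bytes) -> bool:
    """Hypothetical helper: True if data starts like a JPEG, PNG, GIF or BMP file."""
    return any(data.startswith(sig) for sig in SIGNATURES)

# Stand-in for the preprocessed image: a few float32 pixel values in [0, 1].
pixels = struct.pack("<6f", 0.1, 0.2, 0.3, 0.4, 0.5, 0.6)
as_base64 = base64.urlsafe_b64encode(pixels)  # the kind of string passed as `encoded` above

# Stand-in for tf.io.read_file(path) on a PNG: the file starts with the PNG signature.
file_bytes = b"\x89PNG\r\n\x1a\n" + b"\x00" * 16

print(has_supported_image_header(as_base64))   # False: base64 text of raw pixels
print(has_supported_image_header(file_bytes))  # True: encoded file bytes
```

Under that assumption, the call would look something like `inf_model(encoded=tf.io.read_file(path))`, with the scores read from the `output_0` key listed in the signature. This has not been tested against the actual model.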

Then I predict:

    # Output key taken from the signature above; the result is a vector of 19451 scores
    pred = inf_model(encoded=img_encoded)['output_0']

However, I get the following error:

Traceback (most recent call last):

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/spyder_kernels/py3compat.py", line 356, in compat_exec
    exec(code, globals, locals)

  File "/home/james/Desktop/Project/dev_test/inference.py", line 79, in <module>
    res = inf_model(encoded=enc)['sequential_1'][0]

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1669, in __call__
    return self._call_impl(args, kwargs)

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1678, in _call_impl
    return self._call_with_structured_signature(args, kwargs,

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1759, in _call_with_structured_signature
    return self._call_flat(

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py", line 115, in _call_flat
    return super(_WrapperFunction, self)._call_flat(args, captured_inputs,

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1918, in _call_flat
    return self._build_call_outputs(self._inference_function.call(

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 555, in call
    outputs = execute.execute(

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,

InvalidArgumentError:  Unknown image file format. One of JPEG, PNG, GIF, BMP required.
     [[{{node StatefulPartitionedCall/decode_image/DecodeImage}}]] [Op:__inference_signature_wrapper_129216]

Function call stack:
signature_wrapper

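The error is caused by the image format: the data is not in JPEG, PNG, GIF, or BMP format. An image may have a .jpg extension but actually be a TIFF file. Because the error says the image file format is unknown, check the image types. The following code lists the files in the dataset that are not JPEG, PNG, GIF, or BMP so that you can remove them: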

import os
import cv2     # OpenCV, used as a second check that the file really decodes
import imghdr  # note: deprecated since Python 3.11 and removed in Python 3.13

def check_images(s_dir, ext_list):
    """Return (bad_images, bad_ext): paths of files that are not an accepted
    image type or cannot be decoded, and the detected type of each one."""
    bad_images = []
    bad_ext = []
    for klass in os.listdir(s_dir):
        klass_path = os.path.join(s_dir, klass)
        print('processing class directory', klass)
        if not os.path.isdir(klass_path):
            print('*** WARNING *** you have files in', s_dir, '; it should only contain subdirectories')
            continue
        for f in os.listdir(klass_path):
            f_path = os.path.join(klass_path, f)
            if not os.path.isfile(f_path):
                print('*** fatal error: you have a subdirectory', f, 'in class directory', klass)
                continue
            # imghdr inspects the file header, so a TIFF saved as .jpg is reported as 'tiff'
            tip = imghdr.what(f_path)
            if tip not in ext_list:
                bad_images.append(f_path)
                bad_ext.append(tip)
                continue
            # cv2.imread returns None (it does not raise) when it cannot decode a file
            if cv2.imread(f_path) is None:
                print('file', f_path, 'is not a valid image file')
                bad_images.append(f_path)
                bad_ext.append(tip)
    return bad_images, bad_ext

source_dir = r'c:\temp\people\storage'
good_exts = ['jpg', 'png', 'jpeg', 'gif', 'bmp']  # list of acceptable image types
bad_file_list, bad_ext_list = check_images(source_dir, good_exts)
if bad_file_list:
    print('improper image files are listed below')
    for path, tip in zip(bad_file_list, bad_ext_list):
        print(path, '(detected type:', tip, ')')
else:
    print('no improper image files were found')

Removing such images from the dataset should help.
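`imghdr` was deprecated in Python 3.11 and removed in 3.13. One alternative is to read the first bytes of each file and compare them with the JPEG, PNG, GIF, and BMP signatures. This sketch assumes the same layout as the code above, with one subdirectory per class. A header check like this also catches a TIFF saved with a `.jpg` extension. The names `sniff` and `find_unsupported_images` are hypothetical, and the table covers only the common headers of these formats, so treat it as a quick filter rather than a full validator.

```python
import tempfile
from pathlib import Path

# Leading bytes of the formats listed in the error message, plus TIFF for reporting.
SIGNATURES = [
    (b"\xff\xd8\xff", "jpeg"),
    (b"\x89PNG\r\n\x1a\n", "png"),
    (b"GIF87a", "gif"),
    (b"GIF89a", "gif"),
    (b"BM", "bmp"),
    (b"II*\x00", "tiff"),  # little-endian TIFF
    (b"MM\x00*", "tiff"),  # big-endian TIFF
]
SUPPORTED = {"jpeg", "png", "gif", "bmp"}

def sniff(path):
    """Return the format name read from the file header, or None if unrecognised."""
    with open(path, "rb") as fh:
        head = fh.read(8)
    for sig, name in SIGNATURES:
        if head.startswith(sig):
            return name
    return None

def find_unsupported_images(root):
    """Hypothetical helper: list (path, detected_type) for every file in the class
    subdirectories of root whose contents are not JPEG, PNG, GIF or BMP."""
    bad = []
    for class_dir in sorted(Path(root).iterdir()):
        if not class_dir.is_dir():
            continue  # stray files at the top level are skipped here
        for f in sorted(class_dir.iterdir()):
            if f.is_file():
                kind = sniff(f)
                if kind not in SUPPORTED:
                    bad.append((f, kind))
    return bad

# Usage: a PNG that is fine and a TIFF that only pretends to be a JPEG.
with tempfile.TemporaryDirectory() as tmp:
    cls = Path(tmp) / "cats"
    cls.mkdir()
    (cls / "ok.png").write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8)
    (cls / "fake.jpg").write_bytes(b"II*\x00" + b"\x00" * 8)  # TIFF data, .jpg name
    for path, kind in find_unsupported_images(tmp):
        print(path.name, kind)  # fake.jpg tiff
```

Because only the header is read, the scan is fast, but a file with a valid header and a corrupt body will still pass. The `cv2.imread` check in the code above covers that case.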


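Statement: the technical posts on this site are licensed under CC BY-SA 4.0. If you repost them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.

© 2020-2024 STACKOOM.COM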