
How to use a TensorFlow Serving SavedModel with a string tensor input to predict on the local machine?

I'm trying to run a saved serving model on my local machine. However, it takes a string tensor as input, and I'm having trouble converting the images to the correct string format.

To load the model I use:

    import tensorflow as tf

    saved_model = tf.saved_model.load('model/1/')
    inf_model = saved_model.signatures['serving_default']

The model has the following input-output structure:

    inputs {
      key: "encoded"
      value {
        name: "serving_default_encoded:0"
        dtype: DT_STRING
        tensor_shape {
        }
      }
    }
    outputs {
      key: "output_0"
      value {
        name: "StatefulPartitionedCall:0"
        dtype: DT_FLOAT
        tensor_shape {
          dim {
            size: 19451
          }
        }
      }
    }
    method_name: "tensorflow/serving/predict"

To process the image I use this:

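    img = tf.io.read_file(path)
    # Decodes the image to an H x W x 3 tensor of type uint8
    img = tf.io.decode_image(img, channels=3)
    # resize_with_pad returns float32, so the convert_image_dtype call below
    # leaves the values unchanged (it does not rescale them to [0, 1])
    img = tf.image.resize_with_pad(img, 224, 224)
    img = tf.image.convert_image_dtype(img, tf.float32)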

And I try to convert it to string tensor format like this:

    img_encoded = base64.urlsafe_b64encode(img).decode("utf-8")
    enc = tf.constant(img_encoded)

Predicting:

    pred = inf_model(encoded=enc)['sequential_1'][0]

However, I get the following error:

Traceback (most recent call last):

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/spyder_kernels/py3compat.py", line 356, in compat_exec
    exec(code, globals, locals)

  File "/home/james/Desktop/Project/dev_test/inference.py", line 79, in <module>
    res = inf_model(encoded=enc)['sequential_1'][0]

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1669, in __call__
    return self._call_impl(args, kwargs)

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1678, in _call_impl
    return self._call_with_structured_signature(args, kwargs,

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1759, in _call_with_structured_signature
    return self._call_flat(

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py", line 115, in _call_flat
    return super(_WrapperFunction, self)._call_flat(args, captured_inputs,

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1918, in _call_flat
    return self._build_call_outputs(self._inference_function.call(

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 555, in call
    outputs = execute.execute(

  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,

InvalidArgumentError:  Unknown image file format. One of JPEG, PNG, GIF, BMP required.
     [[{{node StatefulPartitionedCall/decode_image/DecodeImage}}]] [Op:__inference_signature_wrapper_129216]

Function call stack:
signature_wrapper

The error is raised because the input is not in one of the supported formats (JPEG, PNG, GIF, or BMP). A file can have a .jpg extension but actually be in a different format, such as TIFF. Since the error states that the image file format is unknown, check the actual type of each image and remove the ones that are not JPEG, PNG, GIF, or BMP from your dataset, for example with the code below:

    import os
    import imghdr  # stdlib; deprecated in Python 3.11 and removed in 3.13
    import cv2

    def check_images(s_dir, ext_list):
        """Return files in s_dir/<class>/ that are not readable images of an allowed type."""
        bad_images = []
        bad_ext = []  # detected type of each file rejected by the header check
        for klass in os.listdir(s_dir):
            klass_path = os.path.join(s_dir, klass)
            if not os.path.isdir(klass_path):
                print('*** WARNING *** you have files in', s_dir, '; it should only contain sub directories')
                continue
            print('processing class directory', klass)
            for f in os.listdir(klass_path):
                f_path = os.path.join(klass_path, f)
                if not os.path.isfile(f_path):
                    print('*** fatal error, you have a sub directory', f, 'in class directory', klass)
                    continue
                # imghdr inspects the file header, not the extension,
                # so a TIFF file saved as .jpg is reported as 'tiff'
                tip = imghdr.what(f_path)
                if tip not in ext_list:
                    bad_images.append(f_path)
                    bad_ext.append(tip)
                    continue
                # cv2.imread returns None (instead of raising) for unreadable files
                if cv2.imread(f_path) is None:
                    print('file', f_path, 'is not a valid image file')
                    bad_images.append(f_path)
        return bad_images, bad_ext

    source_dir = r'c:\temp\people\storage'
    good_exts = ['jpg', 'png', 'jpeg', 'gif', 'bmp']  # acceptable types as reported by imghdr
    bad_file_list, bad_ext_list = check_images(source_dir, good_exts)
    if bad_file_list:
        print('improper image files are listed below')
        for f_path in bad_file_list:
            print(f_path)
    else:
        print('no improper image files were found')
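The `imghdr` module used above was deprecated in Python 3.11 and removed in Python 3.13. A minimal stdlib-only sketch of the same header check, limited to the four formats named in the TensorFlow error, might look like this. `sniff_image_type` and `find_unsupported` are hypothetical helper names, and the `root/<class>/<file>` layout is assumed to match the one `check_images` expects.

```python
from pathlib import Path
from typing import List, Optional

# Leading "magic" bytes of the four formats named in the TensorFlow error message.
MAGIC_NUMBERS = {
    b"\xff\xd8\xff": "jpeg",
    b"\x89PNG\r\n\x1a\n": "png",
    b"GIF87a": "gif",
    b"GIF89a": "gif",
    b"BM": "bmp",
}


def sniff_image_type(header: bytes) -> Optional[str]:
    """Return 'jpeg', 'png', 'gif' or 'bmp' if the header matches, else None."""
    for magic, kind in MAGIC_NUMBERS.items():
        if header.startswith(magic):
            return kind
    return None


def find_unsupported(root: str) -> List[str]:
    """List files in root/<class>/ whose content is not JPEG, PNG, GIF or BMP.

    The file extension is ignored on purpose: only the first bytes are checked.
    """
    bad = []
    for path in sorted(Path(root).glob("*/*")):
        if not path.is_file():
            continue  # skip nested sub directories
        with path.open("rb") as fh:
            header = fh.read(16)
        if sniff_image_type(header) is None:
            bad.append(str(path))
    return bad
```

For example, `find_unsupported(r'c:\temp\people\storage')` would list the files to remove. A header check is cheaper than decoding every image with `cv2.imread`, but it cannot detect files that have a valid header and corrupt content.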

Removing such files from the dataset should resolve the error.

The technical post webpages of this site follow the CC BY-SA 4.0 license. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.
