How to use a TensorFlow Serving saved model with string tensor input to predict on the local machine?
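I am trying to run a saved serving model on my local machine. However, it requires a string tensor as input, and I cannot convert the image into the correct string format.
To load the model, I use: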
```python
import tensorflow as tf

saved_model = tf.saved_model.load('model/1/')
inf_model = saved_model.signatures['serving_default']
```
The model has the following input/output structure:
```text
inputs {
  key: "encoded"
  value {
    name: "serving_default_encoded:0"
    dtype: DT_STRING
    tensor_shape {
    }
  }
}
outputs {
  key: "output_0"
  value {
    name: "StatefulPartitionedCall:0"
    dtype: DT_FLOAT
    tensor_shape {
      dim {
        size: 19451
      }
    }
  }
}
method_name: "tensorflow/serving/predict"
```

To preprocess the image, I use this:
```python
import tensorflow as tf

img = tf.io.read_file(path)
# Decodes the image to an H x W x 3 tensor of type uint8
img = tf.io.decode_image(img, channels=3)
img = tf.image.resize_with_pad(img, 224, 224)
img = tf.image.convert_image_dtype(img, tf.float32)
```
I tried to convert it to the string tensor format as follows:
```python
import base64

img_encoded = base64.urlsafe_b64encode(img).decode("utf-8")
img_encoded = tf.constant(img_encoded)
```
Prediction:
```python
pred = inf_model(encoded=enc)['sequential_1'][0]
```
However, I get the following error:
```text
Traceback (most recent call last):
  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/spyder_kernels/py3compat.py", line 356, in compat_exec
    exec(code, globals, locals)
  File "/home/james/Desktop/Project/dev_test/inference.py", line 79, in <module>
    res = inf_model(encoded=enc)['sequential_1'][0]
  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1669, in __call__
    return self._call_impl(args, kwargs)
  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1678, in _call_impl
    return self._call_with_structured_signature(args, kwargs,
  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1759, in _call_with_structured_signature
    return self._call_flat(
  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py", line 115, in _call_flat
    return super(_WrapperFunction, self)._call_flat(args, captured_inputs,
  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1918, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 555, in call
    outputs = execute.execute(
  File "/home/james/anaconda3/envs/james/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
InvalidArgumentError: Unknown image file format. One of JPEG, PNG, GIF, BMP required.
	 [[{{node StatefulPartitionedCall/decode_image/DecodeImage}}]] [Op:__inference_signature_wrapper_129216]

Function call stack:
signature_wrapper
```
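The failing node is `StatefulPartitionedCall/decode_image/DecodeImage`, which suggests the `serving_default` signature runs its own image decoder on the `encoded` string. That decoder recognizes JPEG, PNG, GIF and BMP by the first bytes of the data. The minimal standard-library illustration below (a hand-written four-byte JPEG prefix, not a real file and not TensorFlow code) shows that base64 encoding replaces those leading bytes with ASCII text, so the decoder can no longer recognize the format:

```python
import base64

# First four bytes of a typical JPEG/JFIF file (SOI marker followed by APP0 marker)
jpeg_header = b"\xff\xd8\xff\xe0"

encoded = base64.urlsafe_b64encode(jpeg_header)
print(encoded)                                   # b'_9j_4A=='

# The JPEG signature 0xFF 0xD8 0xFF is present before encoding...
print(jpeg_header.startswith(b"\xff\xd8\xff"))   # True
# ...but not after, so a header-based decoder cannot identify the format
print(encoded.startswith(b"\xff\xd8\xff"))       # False
```

The pixel tensor in the question is even further removed: after `decode_image` and resizing it holds raw pixel values and carries no file header at all.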
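The error is caused by the image format: the file is not JPEG, PNG, GIF, or BMP. An image can have a .jpg extension but actually be in TIFF format. Since the error says that the image file format is unknown, check the image types and use the following code to find the images in the dataset that are not JPEG, PNG, GIF, or BMP so they can be removed: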
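Because the format is decided by a file's content rather than its extension, a `.jpg` that is really a TIFF can be caught by looking at its first bytes. The sketch below is a minimal stand-in for `imghdr.what` (the `imghdr` module was removed from the standard library in Python 3.13). The signature table and the `sniff_image_type` / `sniff_file` names are hypothetical, and the table covers only the four formats named in the error message plus TIFF, so treat it as an assumption rather than a complete detector:

```python
from typing import Optional

# Hand-written leading-byte signatures: the four formats named in the error
# message, plus both TIFF byte orders so mislabelled TIFF files are reported.
SIGNATURES = [
    (b"\xff\xd8\xff", "jpeg"),
    (b"\x89PNG\r\n\x1a\n", "png"),
    (b"GIF87a", "gif"),
    (b"GIF89a", "gif"),
    (b"BM", "bmp"),
    (b"II*\x00", "tiff"),  # little-endian TIFF
    (b"MM\x00*", "tiff"),  # big-endian TIFF
]

# Formats the error message says the decoder accepts
ACCEPTED = {"jpeg", "png", "gif", "bmp"}


def sniff_image_type(header: bytes) -> Optional[str]:
    """Return the format implied by the first bytes of a file, or None."""
    for magic, kind in SIGNATURES:
        if header.startswith(magic):
            return kind
    return None


def sniff_file(path: str) -> Optional[str]:
    """Read only the first 16 bytes of the file and classify them."""
    with open(path, "rb") as fh:
        return sniff_image_type(fh.read(16))
```

A file is worth removing when `sniff_file(path)` returns `None` or a value outside `ACCEPTED`, for example `"tiff"` for a TIFF saved with a `.jpg` name.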
```python
import imghdr  # deprecated since Python 3.11 and removed in Python 3.13
import os

import cv2  # OpenCV, used to confirm that a file actually decodes


def check_images(s_dir, ext_list):
    """Return (bad_images, bad_ext) for the class subdirectories of s_dir."""
    bad_images = []
    bad_ext = []
    for klass in os.listdir(s_dir):
        klass_path = os.path.join(s_dir, klass)
        print('processing class directory', klass)
        if not os.path.isdir(klass_path):
            print('*** WARNING *** you have files in', s_dir,
                  '- it should only contain subdirectories')
            continue
        for f in os.listdir(klass_path):
            f_path = os.path.join(klass_path, f)
            if not os.path.isfile(f_path):
                print('*** fatal error, you have a subdirectory', f,
                      'in class directory', klass)
                continue
            # imghdr inspects the file header, so a TIFF named .jpg is caught
            tip = imghdr.what(f_path)
            if tip not in ext_list:
                bad_images.append(f_path)
                bad_ext.append(tip)
                continue
            # cv2.imread returns None instead of raising for unreadable files
            if cv2.imread(f_path) is None:
                print('file', f_path, 'is not a valid image file')
                bad_images.append(f_path)
    return bad_images, bad_ext


source_dir = r'c:\temp\people\storage'
good_exts = ['jpg', 'png', 'jpeg', 'gif', 'bmp']  # list of acceptable types
bad_file_list, bad_ext_list = check_images(source_dir, good_exts)
if bad_file_list:
    print('improper image files are listed below')
    for bad_file in bad_file_list:
        print(bad_file)
else:
    print('no improper image files were found')
```
Removing such images from the dataset should help.
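One way to remove them without losing anything is to move the paths returned by `check_images` into a separate folder rather than deleting them, so a file flagged by mistake can be restored. A minimal sketch; the `quarantine` helper and the destination folder are hypothetical and not part of the original answer:

```python
import os
import shutil
from typing import Iterable, List


def quarantine(bad_paths: Iterable[str], dest_dir: str) -> List[str]:
    """Move each flagged file into dest_dir, keeping its class folder name.

    Returns the new locations. Paths that are no longer regular files are skipped.
    """
    moved = []
    for path in bad_paths:
        if not os.path.isfile(path):
            continue
        # Preserve the class subdirectory: .../storage/cats/x.jpg -> dest/cats/x.jpg
        klass = os.path.basename(os.path.dirname(path))
        target_dir = os.path.join(dest_dir, klass)
        os.makedirs(target_dir, exist_ok=True)
        target = os.path.join(target_dir, os.path.basename(path))
        shutil.move(path, target)
        moved.append(target)
    return moved
```

For example, `quarantine(bad_file_list, r'c:\temp\people\rejected')` after running `check_images`. The destination should sit outside `source_dir`; otherwise a later run of `check_images` would treat it as another class directory.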
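Disclaimer: the technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.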