Running Keras h5 model
I am trying to run this h5 model, found here, from the ALASKA2 Image Steganalysis competition.
I want to predict the label of an RGB image, c1.bmp, using the following code:
import efficientnet.tfkeras as efn
import tensorflow as tf
from tensorflow import keras
import numpy as np

def decode_image(filename, image_size=(512, 512)):
    bits = tf.io.read_file(filename)
    image = tf.image.decode_bmp(bits, channels=3)
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.image.resize(image, image_size)
    return image

img = decode_image('imgs/c1.bmp')
model = keras.models.load_model("model.h5")
print(model.predict(img, verbose=1))
However, running this code results in this error:
File "alaska.py", line 20, in <module>
print(model.predict(img, verbose=1))
File "Python38\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1629, in predict
tmp_batch_outputs = self.predict_function(iterator)
File "Python38\lib\site-packages\tensorflow\python\eager\def_function.py", line 828, in __call__
result = self._call(*args, **kwds)
File "Python38\lib\site-packages\tensorflow\python\eager\def_function.py", line 871, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "Python38\lib\site-packages\tensorflow\python\eager\def_function.py", line 725, in _initialize
self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
File "Python38\lib\site-packages\tensorflow\python\eager\function.py", line 2969, in _get_concrete_function_internal_garbage_collected
graph_function, _ = self._maybe_define_function(args, kwargs)
File "Python38\lib\site-packages\tensorflow\python\eager\function.py", line 3361, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "Python38\lib\site-packages\tensorflow\python\eager\function.py", line 3196, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "Python38\lib\site-packages\tensorflow\python\framework\func_graph.py", line 990, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "Python38\lib\site-packages\tensorflow\python\eager\def_function.py", line 634, in wrapped_fn
out = weak_wrapped_fn().__wrapped__(*args, **kwds)
File "Python38\lib\site-packages\tensorflow\python\framework\func_graph.py", line 977, in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
Python38\lib\site-packages\tensorflow\python\keras\engine\training.py:1478 predict_function *
return step_function(self, iterator)
Python38\lib\site-packages\tensorflow\python\keras\engine\training.py:1468 step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
Python38\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:1259 run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
Python38\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2730 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
Python38\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:3417 _call_for_each_replica
return fn(*args, **kwargs)
Python38\lib\site-packages\tensorflow\python\keras\engine\training.py:1461 run_step **
outputs = model.predict_step(data)
Python38\lib\site-packages\tensorflow\python\keras\engine\training.py:1434 predict_step
return self(x, training=False)
Python38\lib\site-packages\tensorflow\python\keras\engine\base_layer.py:998 __call__
input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
Python38\lib\site-packages\tensorflow\python\keras\engine\input_spec.py:271 assert_input_compatibility
raise ValueError('Input ' + str(input_index) +
ValueError: Input 0 is incompatible with layer sequential: expected shape=(None, 512, 512, 3), found shape=(32, 512, 3)
I have Python 3.8.7 and TensorFlow 2.4.1, and I use PyCharm on Windows.
What does this error mean, and how do I fix it?
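The "found shape=(32, 512, 3)" in the error follows from how `predict()` batches its input: it treats axis 0 as the batch axis, so a single (512, 512, 3) image is read as 512 samples of shape (512, 3), and the default `batch_size` of 32 then produces batches of shape (32, 512, 3). The NumPy sketch that follows only simulates that slicing to make the shapes visible; it is an illustration, not TensorFlow's actual batching code, and the zero-filled array is a placeholder for the decoded image.

```python
import numpy as np

# Placeholder for the output of decode_image(): (height, width, channels), no batch axis.
image = np.zeros((512, 512, 3), dtype=np.float32)

# Simulated batching: predict() slices along axis 0 in chunks of batch_size
# (32 by default), so each "sample" is one 512x3 row of the image.
batch_size = 32
first_batch = image[:batch_size]
print(first_batch.shape)  # (32, 512, 3)

# The model's input layer expects a 4-D tensor: (None, 512, 512, 3).
expected_shape = (None, 512, 512, 3)
print(first_batch.ndim == len(expected_shape))  # False
```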
You forgot to add the batch dimension. Just add the following transformation to the decode_image function, before the return statement:
    image = tf.expand_dims(image, axis=0)
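A minimal NumPy sketch of what that line does to the shapes, using `np.expand_dims` as a stand-in for `tf.expand_dims` (both insert a new axis of size 1 at the given position). The helper name `add_batch_axis` is hypothetical, and the zero-filled arrays are placeholders for decoded images. It also shows that several images can be batched with `np.stack`, giving a batch of N instead of 1.

```python
import numpy as np

def add_batch_axis(image: np.ndarray) -> np.ndarray:
    """Hypothetical helper: prepend a batch axis of size 1, like tf.expand_dims(image, axis=0)."""
    return np.expand_dims(image, axis=0)

# Placeholder for one decoded 512x512 RGB image.
image = np.zeros((512, 512, 3), dtype=np.float32)

single = add_batch_axis(image)
print(single.shape)  # (1, 512, 512, 3), compatible with (None, 512, 512, 3)

# Several images can be batched by stacking them along a new axis 0.
images = [image, image]
batch = np.stack(images, axis=0)
print(batch.shape)  # (2, 512, 512, 3)
```

With a batch of one, `model.predict(img)` returns an array whose first dimension is 1, so the prediction for the single image is `model.predict(img)[0]`.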