[英] How to deal with a tf.saved_model.save(model, filepath) error?
I used TFRS to build a custom hybrid recommender; 我使用 TFRS 构建了一个自定义的混合推荐器; the model was trained and makes predictions well. 模型经过训练,能够很好地进行预测。 But when I save it: 但是当我保存它时:
model = ...
filepath = 'saved_model/model0'
tf.saved_model.save(model, filepath)
Error: 错误:
--------------------------------------------------------------------------- FailedPreconditionError Traceback (most recent call last) /tmp/ipykernel_2678112/2773535809.py in <module>
5 #!mkdir -p saved_model
6 filepath = 'saved_model/model0'
----> 7 tf.saved_model.save(model, filepath)
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/saved_model/save.py in save(obj, export_dir, signatures, options) 1278 # pylint: enable=line-too-long 1279 metrics.IncrementWriteApi(_SAVE_V2_LABEL)
-> 1280 save_and_return_nodes(obj, export_dir, signatures, options) 1281 metrics.IncrementWrite(write_version="2") 1282
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/saved_model/save.py in save_and_return_nodes(obj, export_dir, signatures, options, experimental_skip_checkpoint) 1313 1314 _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
-> 1315 _build_meta_graph(obj, signatures, options, meta_graph_def)) 1316 saved_model.saved_model_schema_version = ( 1317 constants.SAVED_MODEL_SCHEMA_VERSION)
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, signatures, options, meta_graph_def) 1485 1486 with save_context.save_context(options):
-> 1487 return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph_impl(obj, signatures, options, meta_graph_def) 1434 1435 # Use _SaveableView to provide a frozen listing of properties and functions.
-> 1436 saveable_view = _SaveableView(checkpoint_graph_view, options, 1437 wrapped_functions) 1438 object_saver = util.TrackableSaver(checkpoint_graph_view)
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/saved_model/save.py in __init__(self, checkpoint_view, options, wrapped_functions)
200 # Run through the nodes in the object graph first for side effects of
201 # creating variables.
--> 202 self._trace_all_concrete_functions()
203
204 (self._trackable_objects, self.node_paths, self._node_ids,
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/saved_model/save.py in _trace_all_concrete_functions(self)
304 for function in self.checkpoint_view.list_functions(obj).values():
305 if isinstance(function, def_function.Function):
--> 306 function._list_all_concrete_functions_for_serialization() # pylint: disable=protected-access
307
308 @property
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in _list_all_concrete_functions_for_serialization(self) 1194 A list of instances of `ConcreteFunction`. 1195 """
-> 1196 concrete_functions = self._list_all_concrete_functions() 1197 seen_signatures = [] 1198 for concrete_function in concrete_functions:
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in _list_all_concrete_functions(self) 1176 """Returns all concrete functions.""" 1177 if self.input_signature is not None:
-> 1178 self.get_concrete_function() 1179 concrete_functions = [] 1180 # pylint: disable=protected-access
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs) 1257 def get_concrete_function(self, *args, **kwargs): 1258 # Implements GenericFunction.get_concrete_function.
-> 1259 concrete = self._get_concrete_function_garbage_collected(*args, **kwargs) 1260 concrete._garbage_collector.release() # pylint: disable=protected-access 1261 return concrete
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs) 1237 if self._stateful_fn is None: 1238 initializers
= []
-> 1239 self._initialize(args, kwargs, add_initializers_to=initializers) 1240 self._initialize_uninitialized_variables(initializers) 1241
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
778 self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
779 self._concrete_stateful_fn = (
--> 780 self._stateful_fn._get_concrete_function_internal_garbage_collected(
# pylint: disable=protected-access
781 *args, **kwds))
782
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args,
**kwargs) 3155 args, kwargs = None, None 3156 with self._lock:
-> 3157 graph_function, _ = self._maybe_define_function(args, kwargs) 3158 return graph_function 3159
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs) 3555 3556 self._function_cache.missed.add(call_context_key)
-> 3557 graph_function = self._create_graph_function(args, kwargs) 3558 self._function_cache.primary[cache_key] = graph_function 3559
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes) 3390 arg_names = base_arg_names + missing_arg_names 3391 graph_function = ConcreteFunction(
-> 3392 func_graph_module.func_graph_from_py_func( 3393 self._name, 3394 self._python_function,
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses) 1141
_, original_func = tf_decorator.unwrap(python_func) 1142
-> 1143 func_outputs = python_func(*func_args, **func_kwargs) 1144 1145 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
670 # the function a weak reference to itself to avoid a reference cycle.
671 with OptionalXlaContext(compile_with_xla):
--> 672 out = weak_wrapped_fn().__wrapped__(*args, **kwds)
673 return out
674
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/training/tracking/tracking.py in _creator()
279 @def_function.function(input_signature=[], autograph=False)
280 def _creator():
--> 281 resource = self._create_resource()
282 return resource
283
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py in <lambda>()
241 # to have a weak reference to the Dataset to avoid creating
242 # reference cycles and making work for the garbage collector.
--> 243 lambda: weak_self._trace_variant_creation()()), # pylint: disable=unnecessary-lambda,protected-access
244 name="_variant_tracker")
245 self._graph_attr = ops.get_default_graph()
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py in _trace_variant_creation(self)
363 # pylint: disable=protected-access
364 graph_def = graph_pb2.GraphDef().FromString(
--> 365 self._as_serialized_graph(external_state_policy=options_lib
366 .ExternalStatePolicy.FAIL).numpy())
367 output_node_names = []
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
550 'in a future version' if date is None else ('after %s' % date),
551 instructions)
--> 552 return func(*args, **kwargs)
553
554 doc = _add_deprecated_arg_notice_to_docstring(
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py in _as_serialized_graph(self, allow_stateful, strip_device_assignment, external_state_policy)
302 if external_state_policy:
303 policy = external_state_policy.value
--> 304 return gen_dataset_ops.dataset_to_graph_v2(
305 self._variant_tensor,
306 external_state_policy=policy,
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/ops/gen_dataset_ops.py in dataset_to_graph_v2(input_dataset, external_state_policy, strip_device_assignment, name) 1158 return _result 1159 except _core._NotOkStatusException as e:
-> 1160 _ops.raise_from_not_ok_status(e, name) 1161 except _core._FallbackException: 1162 pass
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name) 7105 def raise_from_not_ok_status(e, name): 7106 e.message += (" name: " + name if name is not None else "")
-> 7107 raise core._status_to_exception(e) from None # pylint: disable=protected-access 7108 7109
FailedPreconditionError: Failed to serialize the input pipeline graph: ResourceGather is stateful. [Op:DatasetToGraphV2]
# original code does not work
tf.keras.models.save_model(model, filepath)
# new code: set parameter save_format='tf'
tf.keras.models.save_model(
model, filepath, overwrite=True, include_optimizer=True, save_format='tf',
signatures=None, options=None, save_traces=True
)
You can also use the equivalent `model.save` method: 您也可以使用等价的 `model.save` 方法:
# original code does not work
model.save(filepath)
# new code: set parameter save_format='tf'
model.save(filepath, save_format='tf')
I had the same problem in my multitask (TFRS) model, and following the solution posted here resolved the issue. 我在我的多任务 (TFRS) 模型中遇到了同样的问题,按照这里发布的解决方案解决了这个问题。
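The metrics used in the retrieval task are the issue. 检索任务中使用的指标是问题所在。 The retrieval task's metrics (e.g. `tfrs.metrics.FactorizedTopK`) typically hold a `tf.data` dataset of candidates mapped through the candidate model. That dataset cannot be serialized because its map function contains a stateful embedding lookup, which is the `ResourceGather` op named in the `DatasetToGraphV2` error above.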
After training and evaluating the model, I run this code to save it: 在训练和评估模型之后,我运行以下代码来保存它:
model.retrieval_task = tfrs.tasks.Retrieval() # Removes the metrics.
model.compile()
model.save("./models/my_rec_model")
When I need to perform inference, I load it with: 当我需要进行推理时,我使用这个:
model = tf.keras.models.load_model("./models/my_rec_model")
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.