How to deal with a tf.saved_model.save(model, filepath) error?

I used TFRS to build a custom hybrid recommender; the model was trained and makes good predictions. But when I save it:

import tensorflow as tf

model = ...
filepath = 'saved_model/model0'
tf.saved_model.save(model, filepath)

Error:

---------------------------------------------------------------------------
FailedPreconditionError                   Traceback (most recent call last)
/tmp/ipykernel_2678112/2773535809.py in <module>
      5 #!mkdir -p saved_model
      6 filepath = 'saved_model/model0'
----> 7 tf.saved_model.save(model, filepath)

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/saved_model/save.py in save(obj, export_dir, signatures, options)
   1278   # pylint: enable=line-too-long
   1279   metrics.IncrementWriteApi(_SAVE_V2_LABEL)
-> 1280   save_and_return_nodes(obj, export_dir, signatures, options)
   1281   metrics.IncrementWrite(write_version="2")
   1282

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/saved_model/save.py in save_and_return_nodes(obj, export_dir, signatures, options, experimental_skip_checkpoint)
   1313
   1314   _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
-> 1315       _build_meta_graph(obj, signatures, options, meta_graph_def))
   1316   saved_model.saved_model_schema_version = (
   1317       constants.SAVED_MODEL_SCHEMA_VERSION)

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, signatures, options, meta_graph_def)
   1485
   1486   with save_context.save_context(options):
-> 1487     return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
   1434
   1435   # Use _SaveableView to provide a frozen listing of properties and functions.
-> 1436   saveable_view = _SaveableView(checkpoint_graph_view, options,
   1437                                 wrapped_functions)
   1438   object_saver = util.TrackableSaver(checkpoint_graph_view)

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/saved_model/save.py in __init__(self, checkpoint_view, options, wrapped_functions)
    200     # Run through the nodes in the object graph first for side effects of
    201     # creating variables.
--> 202     self._trace_all_concrete_functions()
    203 
    204     (self._trackable_objects, self.node_paths, self._node_ids,

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/saved_model/save.py in _trace_all_concrete_functions(self)
    304       for function in self.checkpoint_view.list_functions(obj).values():
    305         if isinstance(function, def_function.Function):
--> 306           function._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
    307 
    308   @property

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in _list_all_concrete_functions_for_serialization(self)
   1194       A list of instances of `ConcreteFunction`.
   1195     """
-> 1196     concrete_functions = self._list_all_concrete_functions()
   1197     seen_signatures = []
   1198     for concrete_function in concrete_functions:

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in _list_all_concrete_functions(self)
   1176     """Returns all concrete functions."""
   1177     if self.input_signature is not None:
-> 1178       self.get_concrete_function()
   1179     concrete_functions = []
   1180     # pylint: disable=protected-access

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
   1257   def get_concrete_function(self, *args, **kwargs):
   1258     # Implements GenericFunction.get_concrete_function.
-> 1259     concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
   1260     concrete._garbage_collector.release()  # pylint: disable=protected-access
   1261     return concrete

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
   1237       if self._stateful_fn is None:
   1238         initializers = []
-> 1239         self._initialize(args, kwargs, add_initializers_to=initializers)
   1240         self._initialize_uninitialized_variables(initializers)
   1241

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    778     self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
    779     self._concrete_stateful_fn = (
--> 780         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
    781             *args, **kwds))
    782 

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   3155       args, kwargs = None, None
   3156     with self._lock:
-> 3157       graph_function, _ = self._maybe_define_function(args, kwargs)
   3158     return graph_function
   3159

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3555
   3556         self._function_cache.missed.add(call_context_key)
-> 3557           graph_function = self._create_graph_function(args, kwargs)
   3558           self._function_cache.primary[cache_key] = graph_function
   3559

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3390     arg_names = base_arg_names + missing_arg_names
   3391     graph_function = ConcreteFunction(
-> 3392         func_graph_module.func_graph_from_py_func(
   3393             self._name,
   3394             self._python_function,

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
   1141       _, original_func = tf_decorator.unwrap(python_func)
   1142
-> 1143       func_outputs = python_func(*func_args, **func_kwargs)
   1144
   1145       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    670         # the function a weak reference to itself to avoid a reference cycle.
    671         with OptionalXlaContext(compile_with_xla):
--> 672           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    673         return out
    674 

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/training/tracking/tracking.py in _creator()
    279     @def_function.function(input_signature=[], autograph=False)
    280     def _creator():
--> 281       resource = self._create_resource()
    282       return resource
    283 

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py in <lambda>()
    241             # to have a weak reference to the Dataset to avoid creating
    242             # reference cycles and making work for the garbage collector.
--> 243             lambda: weak_self._trace_variant_creation()()),  # pylint: disable=unnecessary-lambda,protected-access
    244         name="_variant_tracker")
    245     self._graph_attr = ops.get_default_graph()

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py in _trace_variant_creation(self)
    363       # pylint: disable=protected-access
    364       graph_def = graph_pb2.GraphDef().FromString(
--> 365           self._as_serialized_graph(external_state_policy=options_lib
    366                                     .ExternalStatePolicy.FAIL).numpy())
    367     output_node_names = []

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    550                 'in a future version' if date is None else ('after %s' % date),
    551                 instructions)
--> 552       return func(*args, **kwargs)
    553 
    554     doc = _add_deprecated_arg_notice_to_docstring(

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py in _as_serialized_graph(self, allow_stateful, strip_device_assignment, external_state_policy)
    302     if external_state_policy:
    303       policy = external_state_policy.value
--> 304       return gen_dataset_ops.dataset_to_graph_v2(
    305           self._variant_tensor,
    306           external_state_policy=policy,

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/ops/gen_dataset_ops.py in dataset_to_graph_v2(input_dataset, external_state_policy, strip_device_assignment, name)
   1158       return _result
   1159     except _core._NotOkStatusException as e:
-> 1160       _ops.raise_from_not_ok_status(e, name)
   1161     except _core._FallbackException:
   1162       pass

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   7105 def raise_from_not_ok_status(e, name):
   7106   e.message += (" name: " + name if name is not None else "")
-> 7107   raise core._status_to_exception(e) from None  # pylint: disable=protected-access
   7108
   7109

FailedPreconditionError: Failed to serialize the input pipeline graph: ResourceGather is stateful. [Op:DatasetToGraphV2]

Answer 1:

# original code does not work
tf.keras.models.save_model(model, filepath)

# new code: set parameter save_format='tf'
tf.keras.models.save_model(
    model, filepath, overwrite=True, include_optimizer=True, save_format='tf',
    signatures=None, options=None, save_traces=True
)

You can also use the model's own save method:

# original code does not work
model.save(filepath)
# new code: set parameter save_format='tf'
model.save(filepath, save_format='tf')

Answer 2:

I had the same problem with my multitask model (TFRS), and following the solution posted here resolved the issue.

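The metrics used in the retrieval task are the issue. In the usual TFRS setup the `FactorizedTopK` metric keeps a reference to the candidates dataset, and that dataset maps the candidates through the embedding model, so saving would have to serialize a tf.data pipeline containing the stateful embedding lookup (`ResourceGather`). That is the op the error complains about.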
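To check which attribute actually holds a dataset before saving, a small recursive walk over the model's attributes can help. The sketch below is a hypothetical helper (`find_datasets` is not a TensorFlow or TFRS API, and `ToyTask`/`ToyModel` are made-up stand-ins). It assumes Keras 2-style models, whose layers and metrics are `tf.Module` subclasses, and it only follows instance attributes, lists, tuples and dicts, so it can miss datasets stored in other ways. On a real TFRS model it would be expected to report a path somewhere inside the retrieval task's metric, but the exact attribute names are TFRS internals and may change between versions.

```python
import collections.abc

import tensorflow as tf


def find_datasets(obj, path="model", _seen=None):
    """Return attribute paths under `obj` that hold a tf.data.Dataset.

    Hypothetical diagnostic helper, not a TensorFlow/TFRS API. It follows
    instance attributes of tf.Module objects plus lists, tuples and dicts only.
    """
    if isinstance(obj, tf.data.Dataset):
        return [path]
    if _seen is None:
        _seen = set()
    if id(obj) in _seen:  # guard against reference cycles
        return []
    _seen.add(id(obj))

    found = []
    if isinstance(obj, tf.Module):
        for name, value in list(vars(obj).items()):
            # `_self_*` attributes are TensorFlow's own tracking bookkeeping;
            # skipping them avoids reporting the same object under noisy paths.
            if name.startswith("_self_"):
                continue
            found += find_datasets(value, f"{path}.{name}", _seen)
    elif isinstance(obj, collections.abc.Mapping):
        for key, value in obj.items():
            found += find_datasets(value, f"{path}[{key!r}]", _seen)
    elif isinstance(obj, collections.abc.Sequence) and not isinstance(obj, (str, bytes)):
        for i, value in enumerate(obj):
            found += find_datasets(value, f"{path}[{i}]", _seen)
    return found


# Toy usage with hypothetical classes: a "task" whose metric-like attribute
# keeps a candidate dataset, roughly like a FactorizedTopK-style metric.
class ToyTask(tf.Module):
    def __init__(self, candidates=None):
        super().__init__()
        self._candidates = candidates


class ToyModel(tf.Module):
    def __init__(self, candidates=None):
        super().__init__()
        self.retrieval_task = ToyTask(candidates)


print(find_datasets(ToyModel(tf.data.Dataset.range(3))))
print(find_datasets(ToyModel()))
```

Any path the helper prints is a candidate for the attribute to replace or drop before calling `tf.saved_model.save`.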

After training and evaluating the model, I run this code to save it:

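import tensorflow_recommenders as tfrs

# `retrieval_task` is this model's attribute name for its Retrieval task.
model.retrieval_task = tfrs.tasks.Retrieval()  # Removes the metrics.
model.compile()
model.save("./models/my_rec_model")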

When I need to perform inference, I use this:

model = tf.keras.models.load_model("./models/my_rec_model")
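The swap-the-task-then-save pattern can be sketched without TFRS using plain `tf.Module` objects. Everything here is hypothetical (`ToyTask`, `ToyRetrievalModel`, `embed` and the tiny embedding table are made up for illustration): `ToyTask` stands in for a retrieval task whose metric keeps a candidate dataset, and that dataset looks up embeddings inside a `tf.data` pipeline, which uses the same kind of stateful `ResourceGather` op named in the error. Whether saving with the original task still attached reproduces the exact `FailedPreconditionError` depends on the TensorFlow version, so the sketch only shows the fixed path: replace the task, save with `tf.saved_model.save`, then reload with `tf.saved_model.load` for inference.

```python
import tempfile

import tensorflow as tf


class ToyTask(tf.Module):
    """Hypothetical stand-in for a retrieval task; `candidates` plays the metric's dataset."""

    def __init__(self, candidates=None):
        super().__init__()
        self.candidates = candidates


class ToyRetrievalModel(tf.Module):
    def __init__(self, num_items=10, dim=4):
        super().__init__()
        self.item_embeddings = tf.Variable(
            tf.random.normal([num_items, dim]), name="item_embeddings")
        # Candidate pipeline that looks embeddings up inside tf.data; a gather
        # on a resource variable runs a ResourceGather op, the stateful op
        # named in the question's error message.
        candidates = tf.data.Dataset.range(num_items).batch(5).map(
            lambda batch_ids: tf.gather(self.item_embeddings, batch_ids))
        self.task = ToyTask(candidates)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.int64)])
    def embed(self, item_ids):
        return tf.gather(self.item_embeddings, item_ids)


model = ToyRetrievalModel()

# Same idea as `model.retrieval_task = tfrs.tasks.Retrieval()`: swap in a
# task without the candidate dataset before exporting.
model.task = ToyTask()

export_dir = tempfile.mkdtemp()
tf.saved_model.save(model, export_dir)

# Inference side: reload the SavedModel and call the exported function.
restored = tf.saved_model.load(export_dir)
ids = tf.constant([0, 3], dtype=tf.int64)
print(restored.embed(ids).numpy())
```

Replacing the whole task object mirrors the TFRS answer: the new task carries no metrics, so the candidate dataset is no longer reachable from the model when it is exported, while the embedding variable and the traced `embed` function still are.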
