簡體   English   中英

如何處理 tf.saved_model.save(model, filepath) 錯誤?

[英]How to deal with tf.saved_model.save(model, filepath) Error?

我使用 TFRS 構建了一個自定義的混合推薦器;model 經過訓練後,能夠很好地進行預測。但是當我用以下方式保存它時,卻出現了錯誤:

model = ...
filepath = 'saved_model/model0'
tf.saved_model.save(model, filepath)

錯誤:

---------------------------------------------------------------------------
FailedPreconditionError                   Traceback (most recent call last)
/tmp/ipykernel_2678112/2773535809.py in <module>
      5 #!mkdir -p saved_model
      6 filepath = 'saved_model/model0'
----> 7 tf.saved_model.save(model, filepath)

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/saved_model/save.py in save(obj, export_dir, signatures, options)    1278   # pylint: enable=line-too-long    1279   metrics.IncrementWriteApi(_SAVE_V2_LABEL)
-> 1280   save_and_return_nodes(obj, export_dir, signatures, options)    1281   metrics.IncrementWrite(write_version="2")    1282 

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/saved_model/save.py in save_and_return_nodes(obj, export_dir, signatures, options, experimental_skip_checkpoint)    1313     1314   _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
-> 1315       _build_meta_graph(obj, signatures, options, meta_graph_def))    1316   saved_model.saved_model_schema_version = (  1317       constants.SAVED_MODEL_SCHEMA_VERSION)

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, signatures, options, meta_graph_def)    1485 1486   with save_context.save_context(options):
-> 1487     return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph_impl(obj, signatures, options, meta_graph_def)    1434     1435   # Use _SaveableView to provide a frozen listing of properties and functions.
-> 1436   saveable_view = _SaveableView(checkpoint_graph_view, options,    1437                                 wrapped_functions)    1438   object_saver = util.TrackableSaver(checkpoint_graph_view)

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/saved_model/save.py in __init__(self, checkpoint_view, options, wrapped_functions)
    200     # Run through the nodes in the object graph first for side effects of
    201     # creating variables.
--> 202     self._trace_all_concrete_functions()
    203 
    204     (self._trackable_objects, self.node_paths, self._node_ids,

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/saved_model/save.py in _trace_all_concrete_functions(self)
    304       for function in self.checkpoint_view.list_functions(obj).values():
    305         if isinstance(function, def_function.Function):
--> 306           function._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
    307 
    308   @property

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in _list_all_concrete_functions_for_serialization(self)    1194       A list of instances of `ConcreteFunction`.    1195     """
-> 1196     concrete_functions = self._list_all_concrete_functions()    1197     seen_signatures = []    1198     for concrete_function in concrete_functions:

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in _list_all_concrete_functions(self)    1176     """Returns all concrete functions."""    1177     if self.input_signature is not None:
-> 1178       self.get_concrete_function()    1179     concrete_functions = []    1180     # pylint: disable=protected-access

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)    1257   def get_concrete_function(self, *args, **kwargs):    1258     # Implements GenericFunction.get_concrete_function.
-> 1259     concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)    1260 concrete._garbage_collector.release()  # pylint: disable=protected-access    1261     return concrete

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)    1237       if self._stateful_fn is None:    1238         initializers = []
-> 1239         self._initialize(args, kwargs, add_initializers_to=initializers)    1240         self._initialize_uninitialized_variables(initializers)    1241 

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    778     self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
    779     self._concrete_stateful_fn = (
--> 780         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
    781             *args, **kwds))
    782 

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)    3155       args, kwargs = None, None    3156     with self._lock:
-> 3157       graph_function, _ = self._maybe_define_function(args, kwargs)    3158     return graph_function    3159 

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)    3555     3556         self._function_cache.missed.add(call_context_key)
-> 3557           graph_function = self._create_graph_function(args, kwargs)    3558           self._function_cache.primary[cache_key] = graph_function    3559 

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)    3390     arg_names = base_arg_names + missing_arg_names    3391     graph_function = ConcreteFunction(
-> 3392         func_graph_module.func_graph_from_py_func(    3393             self._name,    3394             self._python_function,

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)    1141       _, original_func = tf_decorator.unwrap(python_func)    1142 
-> 1143       func_outputs = python_func(*func_args, **func_kwargs)    1144     1145       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    670         # the function a weak reference to itself to avoid a reference cycle.
    671         with OptionalXlaContext(compile_with_xla):
--> 672           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    673         return out
    674 

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/training/tracking/tracking.py in _creator()
    279     @def_function.function(input_signature=[], autograph=False)
    280     def _creator():
--> 281       resource = self._create_resource()
    282       return resource
    283 

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py in <lambda>()
    241             # to have a weak reference to the Dataset to avoid creating
    242             # reference cycles and making work for the garbage collector.
--> 243             lambda: weak_self._trace_variant_creation()()),  # pylint: disable=unnecessary-lambda,protected-access
    244         name="_variant_tracker")
    245     self._graph_attr = ops.get_default_graph()

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py in _trace_variant_creation(self)
    363       # pylint: disable=protected-access
    364       graph_def = graph_pb2.GraphDef().FromString(
--> 365           self._as_serialized_graph(external_state_policy=options_lib
    366                                     .ExternalStatePolicy.FAIL).numpy())
    367     output_node_names = []

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    550                 'in a future version' if date is None else ('after %s' % date),
    551                 instructions)
--> 552       return func(*args, **kwargs)
    553 
    554     doc = _add_deprecated_arg_notice_to_docstring(

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py in _as_serialized_graph(self, allow_stateful, strip_device_assignment, external_state_policy)
    302     if external_state_policy:
    303       policy = external_state_policy.value
--> 304       return gen_dataset_ops.dataset_to_graph_v2(
    305           self._variant_tensor,
    306           external_state_policy=policy,

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/ops/gen_dataset_ops.py in dataset_to_graph_v2(input_dataset, external_state_policy, strip_device_assignment, name)    1158       return _result    1159    except _core._NotOkStatusException as e:
-> 1160       _ops.raise_from_not_ok_status(e, name)    1161     except _core._FallbackException:    1162       pass

/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)    7105 def raise_from_not_ok_status(e, name):    7106   e.message += (" name: " + name if name is not None else "")
-> 7107   raise core._status_to_exception(e) from None  # pylint: disable=protected-access    7108     7109 

FailedPreconditionError: Failed to serialize the input pipeline graph: ResourceGather is stateful. [Op:DatasetToGraphV2]

解決方法:保存時明確指定 save_format='tf':

# original code does not work
tf.keras.models.save_model(model, filepath)

# new code: set parameter save_format='tf'
tf.keras.models.save_model(
    model, filepath, overwrite=True, include_optimizer=True, save_format='tf',
    signatures=None, options=None, save_traces=True
)

您也可以改用 model.save() 的寫法,同樣需要指定 save_format='tf':

# original code does not work
model.save(filepath)
# new code: set parameter save_format='tf'
model.save(filepath, save_format='tf')

我在我的多任務 model (TFRS) 中遇到了同樣的問題,並按照這裡發布的解決方案解決了這個問題。

檢索任務中使用的指標是問題所在。

在訓練和評估 model 之後,我運行以下代碼來保存它:

model.retrieval_task = tfrs.tasks.Retrieval()  # Removes the metrics.
model.compile()
model.save("./models/my_rec_model")

當我需要進行推理時,我用以下代碼重新載入 model:

model = tf.keras.models.load_model("./models/my_rec_model")

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM