How to load a TensorRT SavedModel to a TensorFlow Estimator?
I'm using TensorFlow 1.14, and the following code works for me to load a TensorFlow SavedModel into an Estimator:
estimator = tf.contrib.estimator.SavedModelEstimator(saved_model_dir)
prediction_results = estimator.predict(input_fn)
However, when I used TensorRT to convert the TensorFlow SavedModel to a TensorRT SavedModel, loading it returns an error message:
ValueError: Directory provided has an invalid SavedModel format: saved_models/nvidia_fp16_converted
I have looked into it further, and it appears the problem is that TensorRT does not generate any variables information (including variables.index) in the SavedModel directory, which causes the above error. Does anyone know how to resolve this problem?
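The failure can be reproduced without TensorFlow: the Estimator simply looks for variables/variables.index under the SavedModel directory and raises a ValueError when it is missing, which is exactly the situation after a TF-TRT conversion freezes all variables into constants. A minimal stdlib-only sketch of that check (check_saved_model is a hypothetical stand-in for the Estimator's internal validation):

```python
import os
import tempfile

def check_saved_model(saved_model_dir):
    """Mimics the Estimator's validation: a trainable SavedModel is expected
    to contain variables/variables.index, which TF-TRT does not emit."""
    index_path = os.path.join(saved_model_dir, "variables", "variables.index")
    if not os.path.exists(index_path):
        raise ValueError(
            "Directory provided has an invalid SavedModel format: %s"
            % saved_model_dir)

# A TF-TRT SavedModel: the graph def is present, but there are no
# variables/ files, because all variables were frozen into constants.
with tempfile.TemporaryDirectory() as d:
    open(os.path.join(d, "saved_model.pb"), "w").close()
    try:
        check_saved_model(d)
    except ValueError as e:
        print("raised:", e)
```

This is why the fix below boils down to disabling that check and every later step that assumes variables exist.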
For anyone who is interested, below is the solution I figured out myself. Normally, a TF SavedModel can be loaded into an Estimator using:
estimator = tf.contrib.estimator.SavedModelEstimator(SAVED_MODEL_DIR)
However, errors occur when loading a TensorRT SavedModel because TensorRT converts all variables to constants, so there is no variable information in the SavedModel directory (e.g. no variables.index), and the Estimator fails when it tries to load variables. Steps to fix the problem:
1. Go to the file:
"/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 2330, in _get_saved_model_ckpt
and comment out the check for variables.index:
"""if not gfile.Exists(os.path.join(saved_model_utils.get_variables_dir(saved_model_dir),
                                    compat.as_text('variables.index'))):
  raise ValueError('Directory provided has an invalid SavedModel format: %s'
                   % saved_model_dir)"""
2. Go to the file:
"/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/canned/saved_model_estimator.py", line 145, in __init__
and change it so that the Estimator does not try to load variables from the SavedModel:
"""checkpoint = estimator_lib._get_saved_model_ckpt(saved_model_dir)  # pylint: disable=protected-access
vars_to_warm_start = [name for name, _ in
                      checkpoint_utils.list_variables(checkpoint)]
warm_start_settings = estimator_lib.WarmStartSettings(
    ckpt_to_initialize_from=checkpoint,
    vars_to_warm_start=vars_to_warm_start)"""
warm_start_settings = estimator_lib.WarmStartSettings(
    ckpt_to_initialize_from=estimator_lib._get_saved_model_ckpt(saved_model_dir))
3. Go to the file:
"/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/canned/saved_model_estimator.py", line 256, in _model_fn_from_saved_model
and comment out the check for "global_step", in case the model was generated from NVIDIA examples (so no training was done and "global_step" was not set):
#global_step_tensor = training_util.get_global_step(g)
#training_util.assert_global_step(global_step_tensor)
4. Go to the file:
"/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/checkpoint_utils.py", line 291, in init_from_checkpoint
and put a return at the beginning of the _init_from_checkpoint function so that it does not try to load the checkpoint:
def _init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """See `init_from_checkpoint` for documentation."""
  return
The loading process should work after all the above changes.
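As an alternative to editing the installed files under site-packages, the same effect can be obtained by monkey-patching at runtime before constructing the SavedModelEstimator. Below is a stdlib-only sketch of the pattern for step 4, using a SimpleNamespace as a stand-in for the real tensorflow.python.training.checkpoint_utils module (the TensorFlow patch target is an assumption based on the paths above):

```python
import types

# Stand-in for tensorflow.python.training.checkpoint_utils; in real code
# you would import the actual module and patch its attribute instead.
checkpoint_utils = types.SimpleNamespace()

def _real_init_from_checkpoint(ckpt_dir_or_file, assignment_map):
    # The unpatched function would try to read checkpoint files that a
    # TF-TRT SavedModel does not contain.
    raise IOError("no checkpoint files in a TF-TRT SavedModel")

checkpoint_utils.init_from_checkpoint = _real_init_from_checkpoint

# The patch: make the function return immediately, mirroring the
# `return` inserted at the top of _init_from_checkpoint in step 4.
checkpoint_utils.init_from_checkpoint = lambda *args, **kwargs: None

# Callers (e.g. the Estimator's warm-start machinery) no longer fail:
checkpoint_utils.init_from_checkpoint("saved_models/nvidia_fp16_converted", {})
```

Patching at runtime keeps the site-packages installation pristine and limits the workaround to the one script that loads the TensorRT SavedModel.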