簡體   English   中英

如何將 TensorRT SavedModel 加載到 TensorFlow Estimator?

[英]How to load a TensorRT SavedModel to a TensorFlow Estimator?

我正在使用 TensorFlow 1.14,並將 TensorFlow SavedModel 加載到 Estimator,以下代碼適用於我:

estimator = tf.contrib.estimator.SavedModelEstimator(saved_model_dir)
prediction_results = estimator.predict(input_fn)

但是,當我使用 TensorRT 將 TensorFlow SavedModel 轉換為 TensorRT SavedModel 時,它返回一條錯誤消息:

ValueError: Directory provided has an invalid SavedModel format: saved_models/nvidia_fp16_converted

進一步看了一下,貌似問題是TensorRT在SavedModel目錄下沒有生成任何變量信息(包括variables.index),導致了上面的錯誤。 有誰知道如何解決這個問題?

對於任何感興趣的人,以下是我自己想出的解決方案: 通常,可以使用以下方法將 TF SavedModel 加載到 Estimator:

estimator = tf.contrib.estimator.SavedModelEstimator(SAVED_MODEL_DIR)

但是,在加載 TensorRT SavedModel 時會發生錯誤,因為 TensorRT 將所有變量都轉換為常量,因此 SavedModel 目錄中沒有變量信息(例如,沒有 variables.index)→ 由於 Estimator 嘗試加載變量而發生錯誤。 解決問題的步驟:

  • 我們需要轉到文件: "/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 2330, in _get_saved_model_ckpt並注釋掉對 variables.index 的檢查
"""if not gfile.Exists(os.path.join(saved_model_utils.get_variables_dir(saved_model_dir),
compat.as_text('variables.index'))):
raise ValueError('Directory provided has an invalid SavedModel format: %s'
% saved_model_dir)"""
  • 轉到文件: "/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/canned/saved_model_estimator.py", line 145, in __init__checkpoint_utils.list_variables(checkpoint)]並進行更改,以便 Estimator不嘗試從 SavedModel 加載變量:
"""checkpoint = estimator_lib._get_saved_model_ckpt(saved_model_dir) # pylint: disable=protected-access
vars_to_warm_start = [name for name, _ in
checkpoint_utils.list_variables(checkpoint)]
warm_start_settings = estimator_lib.WarmStartSettings(
ckpt_to_initialize_from=checkpoint,
vars_to_warm_start=vars_to_warm_start)"""
warm_start_settings = estimator_lib.WarmStartSettings(ckpt_to_initialize_from = estimator_lib._get_saved_model_ckpt(saved_model_dir))
  • 轉到文件: "/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/canned/saved_model_estimator.py", line 256, in _model_fn_from_saved_modeltraining_util.assert_global_step(global_step_tensor)並注釋掉global_step”,以防模型是從 NVIDIA 示例生成的(因此未進行任何訓練且未設置“global_step”):
#global_step_tensor = training_util.get_global_step(g)
#training_util.assert_global_step(global_step_tensor)
  • 轉到文件: "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/checkpoint_utils.py", line 291, in init_from_checkpoint init_from_checkpoint_fn)然后將 return 放在 init_from_checkpoint 函數的開頭,這樣它不會嘗試加載檢查點:
def _init_from_checkpoint(ckpt_dir_or_file, assignment_map):
"""See `init_from_checkpoint` for documentation."""
return

完成上述所有更改后,加載過程應該沒問題。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM