簡體   English   中英

繼續在 SavedModel 上訓練或從 SavedModel 加載檢查點

[英]Continue training on SavedModel or load checkpoint from SavedModel

在 tensorflow 1.14 中,很明顯tf.compat.v1.train.init_from_checkpoint可以加載ckpt以繼續訓練(或熱啟動)。 但是,我在SavedModel中找不到任何相應的方法,並且tf.estimator.WarmStartSetting 也僅支持ckpt 這對我來說很奇怪,因為這個答案提到應該有一個檢查點存儲在SavedModel中。 有人知道嗎:

  1. 如何在 SavedModel 中加載檢查點? 或者
  2. 如何在 SavedModel 上熱啟動訓練?

為了加載 SavedModel 以繼續訓練,您可以使用 tf.saved_model.loader.load,如下所示:

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
  tf.saved_model.loader.load(sess, [tag_constants.SERVING], saved_model_location)

為了提供新的輸入數據,您可以獲得輸入張量名稱,如下所示:

signature_def = meta_graph_def.signature_def[tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
inputs = [v.name for v in signature_def.inputs.values()]
input_tensors = [node.split(":")[0] for node in inputs]

然后你可以制作一些feed_dict來為輸入張量提供新的輸入。 獲取 output 張量的方法與我上面概述的方法類似。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM