簡體   English   中英

tensorflow C++ 相當於 tf.trainable_variables()?

[英]tensorflow C++ equivalent of tf.trainable_variables()?

我的目標是從 C++ API 中獲取包含所有可訓練變量的名稱列表。 在 Python 中,這將與 tf.trainable_variables() 相關。

到目前為止,我嘗試了這種方法。 我有一個 tensorflow::GraphDef 對象,我可以看到像這樣創建的所有節點:

for (int i = 0; i < graphDef.node_size(); i++) {
    graphDef.node(i).PrintDebugString();
}

這很棒。 其中一些節點指的是可訓練變量,但我不知道如何獲取該信息/或者是否可能。

該信息在GraphDef對象中不可用。 tf.trainable_variables只返回鍵為tf.GraphKeys.TRAINABLE_VARIABLES的圖形集合,但圖形集合不會保存到GraphDef ,只保存到MetaGraphDef (請參閱導出和導入 MetaGraph )。 如果您想從 C++ 訪問保存的圖中的可訓練變量,您必須導出和導入 MetaGraph,或者使用一致的命名方案來區分它們。

請注意,順便說一下,圖集合將在 TensorFlow 2.x 中被棄用。 有關更多信息,請參閱棄用集合

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM