繁体   English   中英

tensorflow C++ 相当于 tf.trainable_variables()?

[英]tensorflow C++ equivalent of tf.trainable_variables()?

我的目标是从 C++ API 中获取包含所有可训练变量的名称列表。 在 Python 中,这将与 tf.trainable_variables() 相关。

到目前为止,我尝试了这种方法。 我有一个 tensorflow::GraphDef 对象,我可以看到像这样创建的所有节点:

for (int i = 0; i < graphDef.node_size(); i++) {
    graphDef.node(i).PrintDebugString();
}

这很棒。 其中一些节点指的是可训练变量,但我不知道如何获取该信息/或者是否可能。

该信息在GraphDef对象中不可用。 tf.trainable_variables只返回键为tf.GraphKeys.TRAINABLE_VARIABLES的图形集合,但图形集合不会保存到GraphDef ,只保存到MetaGraphDef (请参阅导出和导入 MetaGraph )。 如果您想从 C++ 访问保存的图中的可训练变量,您必须导出和导入 MetaGraph,或者使用一致的命名方案来区分它们。

请注意,顺便说一下,图集合将在 TensorFlow 2.x 中被弃用。 有关更多信息,请参阅弃用集合

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM