TensorFlow C++ equivalent of tf.trainable_variables()?

My goal is to get a list of the names of all trainable variables from the C++ API. In Python this would be done with tf.trainable_variables().

So far I tried this approach. I have a tensorflow::GraphDef object and I can see all nodes that have been created like this:

for (int i = 0; i < graphDef.node_size(); i++) {
    graphDef.node(i).PrintDebugString();
}

which is great. Some of those nodes refer to trainable variables, but I don't know how to get that information, or whether it's possible at all.

That information is not available in the GraphDef object. tf.trainable_variables() just returns the graph collection with key tf.GraphKeys.TRAINABLE_VARIABLES, but graph collections are not saved to the GraphDef, only to the MetaGraphDef (see Exporting and Importing a MetaGraph). If you want to access trainable variables in a saved graph from C++, you have to either export and import the MetaGraph instead of the bare GraphDef or, possibly, use a consistent naming scheme so that trainable variables can be told apart from other nodes by name.
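As a rough illustration of the naming-scheme option, here is a minimal sketch. It assumes, purely for the example, that every trainable variable was created under a "trainable/" name scope on the Python side. NodeInfo, IsVariableOp and TrainableVariableNames are hypothetical names, not part of the TensorFlow API; NodeInfo only stands in for the name and op fields of a tensorflow::NodeDef so the sketch compiles without TensorFlow.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for the two tensorflow::NodeDef fields used here.
// With a real GraphDef you would read node.name() and node.op() instead.
struct NodeInfo {
    std::string name;
    std::string op;
};

// True for the op types that create a variable in a TF 1.x graph.
bool IsVariableOp(const std::string& op) {
    return op == "Variable" || op == "VariableV2" || op == "VarHandleOp";
}

// Collects the names of variable nodes whose name starts with `prefix`.
// The "trainable/" default is an assumed convention, not something
// TensorFlow enforces; it only works if the graph was built that way.
std::vector<std::string> TrainableVariableNames(
    const std::vector<NodeInfo>& nodes,
    const std::string& prefix = "trainable/") {
    std::vector<std::string> names;
    for (const NodeInfo& node : nodes) {
        const bool has_prefix =
            node.name.compare(0, prefix.size(), prefix) == 0;
        if (has_prefix && IsVariableOp(node.op)) {
            names.push_back(node.name);
        }
    }
    return names;
}
```

In the loop from the question, the input could be filled from graphDef.node(i).name() and graphDef.node(i).op(). Checking the op type as well as the prefix keeps helper nodes such as ".../read" or ".../Assign", which share the variable's name scope, out of the result.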

Note, by the way, that graph collections will be deprecated in TensorFlow 2.x. See Deprecating collections for more information.
