简体   繁体   English

如何在张量流图中得到每个节点的输入形状?

[英]how to get the shape of the input in each node in tensorflow graph?

Hi: Now I am working on converting a tensorflow checkpoint model into a caffe model. 嗨:现在我正在努力将张量流检查点模型转换为caffe模型。 I have succeded in reading the graph and have extracted the attr values in each node. 我已成功读取图形并已提取每个节点中的attr值。 I got the values of 'dilations', 'strides' and 'padding' attr in "Conv2D" node and the shapes in "weights" node, but I couldn't get the value of 'shape' attr, it's empty in Conv2D's input node. 我在“Conv2D”节点中获得了“dilations”,“strides”和“padding”attr的值以及“weight”节点中的形状,但是我无法获得'shape'attr的值,在Conv2D的输入中它是空的节点。 However, these shapes are shown in tensorboard's graphs. 但是,这些形状显示在张量板的图表中。 here is my code: 这是我的代码:

new_saver = tf.train.import_meta_graph(meta_path)          
new_saver.restore(sess, tf.train.latest_checkpoint(ckpt_path))
graph_def = sess.graph_def
node_list = graph_def.node

# conv_node, weight_node, from_node are all in node_list
# conv_node: the conv2d node in graph_def
# weight_node: the weights node of conv2d
# from_node: the input feature map node of conv2d

weight_shape_attr = weight_node.attr['shape']
weight_shapes = [dim.size for dim in weight_shape_attr.shape.dim]

strides = [ii for ii in conv_node.attr['strides'].list.i]
dilations = [ii for ii in conv_node.attr['dilations'].list.i]

shapes = from_node.attr['shape']  # this is empty

and the tensorboard graph: tensorboard_graph 张量板图: tensorboard_graph

Note that the input of the Conv2D node has the shape of ?x79x79x32, it must have been stored somewhere in the model file. 请注意,Conv2D节点的输入形状为?x79x79x32,它必须存储在模型文件中的某个位置。 Can any one give some help? any hits will be helpful, thanks. 任何人都可以提供一些帮助吗?任何点击都会有所帮助,谢谢。

Tensorflow graphs have as_graph_def method that has optional parameter add_shapes ( False by default). Tensorflow图具有as_graph_def方法,该方法具有可选参数add_shapes (默认为False )。 If set to True it results in additional attribute of nodes: _output_shapes . 如果设置为True则会产生节点的其他属性: _output_shapes

So you can try getting GraphDef this way: 因此,您可以尝试以这种方式获取GraphDef:

graph_def = sess.graph.as_graph_def(add_shapes=True)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM