[英]List of tensor names in graph in Tensorflow
论文没有准确反映模型。 如果你从 arxiv 下载源代码,它有一个准确的模型描述作为 model.txt,其中的名称与发布的模型中的名称密切相关。
要回答您的第一个问题, sess.graph.get_operations()
为您提供了一个操作列表。 对于操作, op.name
为您提供名称, op.values()
为您提供它生成的张量列表(在 inception-v3 模型中,所有张量名称都是附加了“:0”的操作名称,所以pool_3:0
是最终池化操作产生的张量。)
以上答案都是正确的。 对于上述任务,我遇到了一个易于理解/简单的代码。 所以在这里分享:-
import tensorflow as tf
def printTensors(pb_file):
# read pb into graph_def
with tf.gfile.GFile(pb_file, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# import graph_def
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def)
# print operations
for op in graph.get_operations():
print(op.name)
printTensors("path-to-my-pbfile.pb")
查看图中的操作(你会看到很多,所以为了简短起见,我在这里只给出了第一个字符串)。
sess = tf.Session()
op = sess.graph.get_operations()
[m.values() for m in op][1]
out:
(<tf.Tensor 'conv1/weights:0' shape=(4, 4, 3, 32) dtype=float32_ref>,)
您甚至不必创建会话即可查看图中所有操作名称的名称。 为此,您只需要获取默认图形tf.get_default_graph()
并提取所有操作: .get_operations
。 每个操作都有许多字段,您需要的是名称。
这是代码:
import tensorflow as tf
a = tf.Variable(5)
b = tf.Variable(6)
c = tf.Variable(7)
d = (a + b) * c
for i in tf.get_default_graph().get_operations():
print i.name
作为嵌套列表理解:
tensor_names = [t.name for op in tf.get_default_graph().get_operations() for t in op.values()]
获取图形中张量名称的函数(默认为默认图形):
def get_names(graph=tf.get_default_graph()):
return [t.name for op in graph.get_operations() for t in op.values()]
在图中获取张量的函数(默认为默认图):
def get_tensors(graph=tf.get_default_graph()):
return [t for op in graph.get_operations() for t in op.values()]
saved_model_cli
是 TF 附带的替代命令行工具,如果您处理“SavedModel”格式,它可能会很有用。 从文档
!saved_model_cli show --dir /tmp/mobilenet/1 --tag_set serve --all
此输出可能很有用,例如:
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['__saved_model_init_op']:
The given SavedModel SignatureDef contains the following input(s):
The given SavedModel SignatureDef contains the following output(s):
outputs['__saved_model_init_op'] tensor_info:
dtype: DT_INVALID
shape: unknown_rank
name: NoOp
Method name is:
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['dense_input'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1280)
name: serving_default_dense_input:0
The given SavedModel SignatureDef contains the following output(s):
outputs['dense_1'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: StatefulPartitionedCall:0
Method name is: tensorflow/serving/predict
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.