繁体   English   中英

如何在 Tensorflow SavedModel 中列出所有使用过的操作?

[英]How to list all used operations in Tensorflow SavedModel?

如果我使用tensorflow.saved_model.save函数以 SavedModel 格式保存我的模型,那么我如何检索此模型中使用了哪些 Tensorflow Ops。 由于模型可以恢复,所以这些操作都存储在图中,我猜是在saved_model.pb文件中。 如果我加载这个 protobuf(所以不是整个模型),protobuf 的库部分会列出这些,但目前没有记录并标记为实验功能。 在 Tensorflow 1.x 中创建的模型将没有这部分。

那么,从 SavedModel 格式的模型中检索已使用操作列表(如MatchingFilesWriteFile )的快速可靠方法是什么?

现在我可以冻结整个事情,就像tensorflowjs-converter一样。 因为他们还检查支持的操作。 当 LSTM 在模型中时,这当前不起作用,请参见此处 有没有更好的方法来做到这一点,因为 Ops 肯定在那里?

示例模型:

class FileReader(tf.Module):

@tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
def read_disk(self, file_name):
    input_scalar = tf.reshape(file_name, [])
    output = tf.io.read_file(input_scalar)
    return tf.stack([output], name='content')

file_reader = FileReader()

tf.saved_model.save(file_reader, 'file_reader')

预期输出所有操作,在这种情况下至少包含:

如果saved_model.pbSavedModel protobuf 消息,那么您可以直接从那里获取操作。 假设我们创建一个模型如下:

import tensorflow as tf

class FileReader(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
    def read_disk(self, file_name):
        input_scalar = tf.reshape(file_name, [])
        output = tf.io.read_file(input_scalar)
        return tf.stack([output], name='content')

file_reader = FileReader()
tf.saved_model.save(file_reader, 'tmp')

我们现在可以找到该模型使用的操作,如下所示:

from tensorflow.core.protobuf.saved_model_pb2 import SavedModel

saved_model = SavedModel()
with open('tmp/saved_model.pb', 'rb') as f:
    saved_model.ParseFromString(f.read())
model_op_names = set()
# Iterate over every metagraph in case there is more than one
for meta_graph in saved_model.meta_graphs:
    # Add operations in the graph definition
    model_op_names.update(node.op for node in meta_graph.graph_def.node)
    # Go through the functions in the graph definition
    for func in meta_graph.graph_def.library.function:
        # Add operations in each function
        model_op_names.update(node.op for node in func.node_def)
# Convert to list, sorted if you want
model_op_names = sorted(model_op_names)
print(*model_op_names, sep='\n')
# Const
# Identity
# MergeV2Checkpoints
# NoOp
# Pack
# PartitionedCall
# Placeholder
# ReadFile
# Reshape
# RestoreV2
# SaveV2
# ShardedFilename
# StatefulPartitionedCall
# StringJoin

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM