繁体   English   中英

将 Audioset ckpt 转换为 pb 文件

[英]Converting Audioset ckpt to pb file

我正在使用Audioset/VGGish并尝试将它们提供的检查点文件转换为 .pb 文件。 问题是为训练模型提供的仅有两个项目是 ckpt 文件(上面链接)和npz 文件

这是我尝试此操作的第三次尝试,现在已经花了几个小时试图找到执行此操作的最佳工具。 到目前为止,我已经找到了几种解决方案,但它们似乎都需要比 ckpt 文件更多的信息。 请记住,ckpt 文件和 Audioset 通常需要使用 TensorFlow < 2。


例子:

freeze_graph :我总是以ValueError: You need to supply the name of a node to --output_node_names错误结束ValueError: You need to supply the name of a node to --output_node_names无论我在那里输入什么值, ValueError: You need to supply the name of a node to --output_node_names 该示例使用softmax但问题是我似乎无法弄清楚如何从 ckpt 文件中提取节点名称,因此似乎我无法在不知道它们的情况下添加有效值。

记录的 GitHub 问题:按照 OP 的代码,我收到错误ValueError: No variables to save

Stack Overflow 问题:这似乎是一个可靠的答案,但 GitHub 存储库不提供.ckpt.meta文件。 我认为在某些情况下通常需要元信息? 我抬头看看是否有任何方法可以从 ckpt 文件中提取元来创建元文件,然后运行信息,因为元文件似乎是没有值的 ckpt 文件的结构或图形(来自这个答案: Tensorflow : .ckpt 文件和 .ckpt.meta 和 .ckpt.index 和 .pb 文件之间的关系是什么)但我可能会误解这一点。

我认为有一种方法可以提取元文件的原因之一是有人登录到 MMdnn GitHub 上的这个问题: Convert Audioset VGG from tensorflow to pytorch 虽然没有转换为 .pb,但他们的命令中有一个 ckpt.meta 文件。 该文件在他们的描述中没有链接,谷歌搜索“vggish_model.ckpt.meta”只会发现 GitHub 问题。 我已就此问题向 OP 发送消息,以查看他们是否可以说明该文件的来源。

上一篇(2018 年)带有转换脚本的文章:这是一篇相对较旧的文章。 我可以让脚本运行,但也会得到错误ValueError: No variables to save


如果有人能指出我正确的方向,那就太好了; 我已经开始用尽我的选择。 似乎有一些我正在尝试的好的解决方案,但我可能只是错过了一两步(或一两个文件)才能成功转换。

谢谢你的帮助!

我希望这个回复还不算太晚,但我设法使用存储库中提供的推理代码生成了 .pb 文件。

Obs:由于我的 GPU,我使用 tensorflow 1.4.1,所以这可能不适用于较新的版本,或者需要进行一些更改。

推理演示将图形和检查点数据加载到会话中。 从那里我可以使用一个函数来保存会话和图表。 这是我的代码示例:

import vggish_input
from tensorflow.python.tools import freeze_graph
def save(sess, directory, filename, saver):
    """
    This function saves a checkpoint, based on the current session
    """
    if not os.path.exists(directory):
        os.makedirs(directory)
    filepath = os.path.join(directory, filename)
    saver.save(sess, filepath)
    return filepath

def save_as_pb(sess, directory, filename, saver):
    """
    This function saves a checkpoint, then writes the graph in a pbtxt, and then              makes a frozen graph with the chekpoint and the pbtxt
    """

    # Save checkpoint to freeze graph later
    ckpt_filepath = save(sess, directory=directory, filename=filename, saver=saver)
    pbtxt_filename = filename + '.pbtxt'
    pbtxt_filepath = os.path.join(directory, pbtxt_filename)
    pb_filepath = os.path.join(directory, filename + '.pb')

    # This will only save the graph but the variables will not be saved.
    tf.train.write_graph(graph_or_graph_def=sess.graph_def, logdir=directory, name=pbtxt_filename, as_text=True)

    # Freeze graph, combining the checkpoint and 
    freeze_graph.freeze_graph(input_graph=pbtxt_filepath, input_saver='', input_binary=False, input_checkpoint=ckpt_filepath, output_node_names=vggish_params.OUTPUT_TENSOR_NAME.split(':')[0], restore_op_name='save/restore_all', filename_tensor_name='save/Const:0', output_graph=pb_filepath, clear_devices=True, initializer_nodes='')

    return pb_filepath

然后我在从vggish_inference_demo.py文件中的检查点加载模型后立即插入了save_as_pb

  config = tf.ConfigProto()
  config.gpu_options.allow_growth=True
  with tf.Graph().as_default(), tf.Session(config=config) as sess:
    # Define the model in inference mode, load the checkpoint, and
    # locate input and output tensors.
    vggish_slim.define_vggish_slim(training=False)
    vggish_slim.load_vggish_slim_checkpoint(sess, checkpoint)
    features_tensor = sess.graph.get_tensor_by_name(
        vggish_params.INPUT_TENSOR_NAME)
    embedding_tensor = sess.graph.get_tensor_by_name(
        vggish_params.OUTPUT_TENSOR_NAME)
    saver = tf.train.Saver()
    save_as_pb(sess, './saved_vggish/', 'vggish', saver)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM