繁体   English   中英

如何将 TensorFlow 检查点文件转换为 TensorFlowJS?

[英]How to convert TensorFlow checkpoint files to TensorFlowJS?

我认为我有一个在 TensorFlow v1 上开发的项目。 它在 Python 3.8 中工作,如下所示:

 ...
 saver = tf.train.Saver(var_list=vars)
 ...
 saver.restore(self.sess, tf.train.latest_checkpoint(checkpoint_dir))
 ...

检查点文件位于“checkpoint_dir”中

我想将它与 TFjs 一起使用,但我不知道如何将检查点文件转换为可以用 TFjs 加载的东西。

我应该怎么办?

谢谢,

约翰

好的,我想通了。 希望这对像我这样的其他初学者也有帮助。

检查点文件不包含 model,它们仅包含 model 的值(权重等)。

model 实际上是内置在代码中的。 因此,以下是将 Tensorflow v1 检查点文件转换为 TensorflowJS 可加载 model 的步骤:

  1. 首先,我再次保存了检查点,因为缺少一个文件(.meta 文件),其中包含有关检查点中值的一些元信息。 为了使用 meta 保存检查点,我在saver.restore(...调用之后立即使用了此代码,如下所示:
...
saver.save(self.sess,save_path='./newcheckpoint/')
...
  1. 将 model 保存为冻结的 model 文件,如下所示:
import tensorflow.compat.v1 as tf

meta_path = './newcheckpoint/.meta' # Your .meta file
output_node_names = ['name_of_the_output_node']    # Output nodes

with tf.Session() as sess:
    # Restore the graph
    saver = tf.train.import_meta_graph(meta_path)

    # Load weights
    saver.restore(sess,tf.train.latest_checkpoint('./newcheckpoint/'))

    # Freeze the graph
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names)

    # Save the frozen graph
    with open('./freeze/output_graph.pb', 'wb') as f:
      f.write(frozen_graph_def.SerializeToString())

这会将 model 保存到./freeze/output_graph.pb

  1. 使用tensorflowjs_converter将冻结的 model 转换为 web model,如下所示:

tensorflowjs_converter --input_format=tf_frozen_model --output_node_names='final_add' --skip_op_check./freeze/output_graph.pb./web_model/

由于尝试转换时缺少一些操作错误/警告,不得不使用--skip_op_check

作为第 3 步的结果,. ./webmodel/文件夹将包含 TensorflowJS 库所需的 JSON 和二进制文件。

这是我使用 tfjs 2.x 加载 model 的方法:

model=await tf.loadGraphModel('web_model/model.json');

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM