简体   繁体   English

如何将 TensorFlow 检查点文件转换为 TensorFlowJS?

[英]How to convert TensorFlow checkpoint files to TensorFlowJS?

I have a project that was developed on TensorFlow v1 I think.我认为我有一个在 TensorFlow v1 上开发的项目。 It works in Python 3.8 like this:它在 Python 3.8 中工作,如下所示:

 ...
 saver = tf.train.Saver(var_list=vars)
 ...
 saver.restore(self.sess, tf.train.latest_checkpoint(checkpoint_dir))
 ...

The checkpoint files reside in the "checkpoint_dir"检查点文件位于“checkpoint_dir”中

I would like to use this with TFjs but I can't figure out how to transform the checkpoint files to something that can be loaded with TFjs.我想将它与 TFjs 一起使用,但我不知道如何将检查点文件转换为可以用 TFjs 加载的东西。

What should I do?我应该怎么办?

thanks,谢谢,

John约翰

Ok, I figured it out.好的,我想通了。 Hope this helps other beginners like me too.希望这对像我这样的其他初学者也有帮助。

The checkpoint files do not contain the model, they only contain the values (weights, etc) of the model.检查点文件不包含 model,它们仅包含 model 的值(权重等)。

The model is actually built in the code. model 实际上是内置在代码中的。 So, here are the steps to convert the Tensorflow v1 checkpoint files to TensorflowJS loadable model:因此,以下是将 Tensorflow v1 检查点文件转换为 TensorflowJS 可加载 model 的步骤:

  1. First I saved the checkpoint again because there was a file that was missing (.meta file) This contains some meta information about the values in the checkpoint.首先,我再次保存了检查点,因为缺少一个文件(.meta 文件),其中包含有关检查点中值的一些元信息。 To save the checkpoint with meta I used this code right after the saver.restore(... call like this:为了使用 meta 保存检查点,我在saver.restore(...调用之后立即使用了此代码,如下所示:
...
saver.save(self.sess,save_path='./newcheckpoint/')
...
  1. Save the model as a frozen model file like this:将 model 保存为冻结的 model 文件,如下所示:
import tensorflow.compat.v1 as tf

meta_path = './newcheckpoint/.meta' # Your .meta file
output_node_names = ['name_of_the_output_node']    # Output nodes

with tf.Session() as sess:
    # Restore the graph
    saver = tf.train.import_meta_graph(meta_path)

    # Load weights
    saver.restore(sess,tf.train.latest_checkpoint('./newcheckpoint/'))

    # Freeze the graph
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names)

    # Save the frozen graph
    with open('./freeze/output_graph.pb', 'wb') as f:
      f.write(frozen_graph_def.SerializeToString())

This will save the model to ./freeze/output_graph.pb这会将 model 保存到./freeze/output_graph.pb

  1. Using tensorflowjs_converter convert the frozen model to a web model like this:使用tensorflowjs_converter将冻结的 model 转换为 web model,如下所示:

tensorflowjs_converter --input_format=tf_frozen_model --output_node_names='final_add' --skip_op_check./freeze/output_graph.pb./web_model/

Had to use the --skip_op_check due to some missing op errors/warnings when trying to convert.由于尝试转换时缺少一些操作错误/警告,不得不使用--skip_op_check

As a result of step 3, the ./webmodel/ folder will contain the JSON and binary files required by the TensorflowJS library.作为第 3 步的结果,. ./webmodel/文件夹将包含 TensorflowJS 库所需的 JSON 和二进制文件。

Here's how I load the model using tfjs 2.x:这是我使用 tfjs 2.x 加载 model 的方法:

model=await tf.loadGraphModel('web_model/model.json');

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM