
How to convert TensorFlow checkpoint files to TensorFlowJS?

I have a project that was developed on TensorFlow v1, I think. It works in Python 3.8 like this:

 ...
 saver = tf.train.Saver(var_list=vars)
 ...
 saver.restore(self.sess, tf.train.latest_checkpoint(checkpoint_dir))
 ...

The checkpoint files reside in the "checkpoint_dir" directory.

I would like to use this with TFjs, but I can't figure out how to convert the checkpoint files into something that can be loaded with TFjs.

What should I do?

thanks,

John

Ok, I figured it out. Hope this helps other beginners like me too.

The checkpoint files do not contain the model; they only contain the values (weights, etc.) of the model.
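As a quick sanity check, you can list what a checkpoint actually stores: just variable names, shapes and values, no graph. Here's a minimal sketch, assuming checkpoint_dir is the same directory as in the question:

import tensorflow.compat.v1 as tf

# Print the name and shape of every variable stored in the checkpoint.
# checkpoint_dir is assumed to be the directory from the question above.
checkpoint_dir = './checkpoint_dir'
for name, shape in tf.train.list_variables(tf.train.latest_checkpoint(checkpoint_dir)):
    print(name, shape)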

The model itself is built in the code. So, here are the steps to convert the TensorFlow v1 checkpoint files into a model that TensorFlow.js can load:

  1. First, I saved the checkpoint again because one file was missing (the .meta file), which holds the graph definition and other metadata for the values in the checkpoint. To save the checkpoint together with the .meta file, I added this code right after the saver.restore(...) call:
...
saver.restore(self.sess, tf.train.latest_checkpoint(checkpoint_dir))
# Re-save the restored session so that the .meta graph file is written as well
saver.save(self.sess, save_path='./newcheckpoint/')
...
  2. Save the model as a frozen model file like this:
import tensorflow.compat.v1 as tf
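# Note: if you run this under TensorFlow 2.x, graph-mode calls like tf.Session()
# may require eager execution to be disabled first:
# tf.disable_eager_execution()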

meta_path = './newcheckpoint/.meta' # Your .meta file
output_node_names = ['name_of_the_output_node']    # Output nodes

with tf.Session() as sess:
    # Restore the graph
    saver = tf.train.import_meta_graph(meta_path)

    # Load weights
    saver.restore(sess, tf.train.latest_checkpoint('./newcheckpoint/'))

    # Freeze the graph
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names)

    # Save the frozen graph (the ./freeze/ directory must already exist)
    with open('./freeze/output_graph.pb', 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())

This will save the frozen model to ./freeze/output_graph.pb.
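If you don't know what name to put in output_node_names, you can print all node names after importing the meta graph and pick the output node from the list. A minimal sketch, assuming the re-saved checkpoint from step 1 is in ./newcheckpoint/:

import tensorflow.compat.v1 as tf

# List every node in the restored graph so you can identify the output node
# name(s) to pass as output_node_names.
with tf.Session() as sess:
    tf.train.import_meta_graph('./newcheckpoint/.meta')
    for node in sess.graph.as_graph_def().node:
        print(node.name, node.op)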

  3. Using tensorflowjs_converter, convert the frozen model to a web model like this:

tensorflowjs_converter --input_format=tf_frozen_model --output_node_names='final_add' --skip_op_check ./freeze/output_graph.pb ./web_model/

I had to use --skip_op_check because of some missing-op errors/warnings during conversion.
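If you want to see which ops the frozen graph actually contains (for example, to understand where the missing-op warnings come from), you can load it back and list the op types. A minimal sketch, assuming the path from step 2:

import tensorflow.compat.v1 as tf

# Load the frozen graph and print the distinct op types it uses.
graph_def = tf.GraphDef()
with open('./freeze/output_graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())
print(sorted({node.op for node in graph_def.node}))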

As a result of step 3, the ./web_model/ folder will contain the JSON model file and binary weight files required by the TensorFlow.js library.

Here's how I load the model using tfjs 2.x:

model = await tf.loadGraphModel('web_model/model.json');
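Once loaded, the graph model can be run with model.predict(...) or model.execute(...) on your input tensors.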
