
Loading a saved TensorFlow model from its .meta file

I am trying to load a TensorFlow meta graph from a saved checkpoint using TensorFlow 1.15, in order to convert it to a SavedModel for TensorFlow Serving. It is a speech recognition model with local attention and a unidirectional LSTM, implemented with the RETURNN toolkit using the TensorFlow backend. I am using the following code:

import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
import sys

if len(sys.argv) != 2:
    print("Usage: " + sys.argv[0] + " save_dir")
    exit(1)
export_dir = sys.argv[1]
builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir)
sigs = {}
with tf.Session(graph=tf.Graph()) as sess:
    # Import the graph structure from the .meta file and restore the weights.
    new_saver = tf.train.import_meta_graph("./serv_test/model.238.meta")
    new_saver.restore(sess, tf.train.latest_checkpoint("./serv_test"))
    graph = tf.get_default_graph()
    # Look up the model's input and output tensors by name.
    input_audio = graph.get_tensor_by_name('inference/default/wav:0')
    output_hyps = graph.get_tensor_by_name('inference/default/Reshape_7:0')
    # Build the serving signature and write out the SavedModel.
    sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
        tf.saved_model.signature_def_utils.predict_signature_def(
            {"in": input_audio}, {"out": output_hyps})
    builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING], signature_def_map=sigs)
builder.save()

However, I am getting the following error on the import_meta_graph line:

Traceback (most recent call last):
  File "xport.py", line 16, in <module>
    new_saver=tf.train.import_meta_graph("./serv_test/model.238.meta")
  File "/home/ubuntu/tf1.15/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1453, in import_meta_graph
    **kwargs)[0]
  File "/home/ubuntu/tf1.15/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1477, in _import_meta_graph_with_return_elements
    **kwargs))
  File "/home/ubuntu/tf1.15/lib/python3.6/site-packages/tensorflow_core/python/framework/meta_graph.py", line 809, in import_scoped_meta_graph_with_return_elements
    return_elements=return_elements)
  File "/home/ubuntu/tf1.15/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/ubuntu/tf1.15/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 405, in import_graph_def
    producer_op_list=producer_op_list)
  File "/home/ubuntu/tf1.15/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 501, in _import_graph_def_internal
    graph._c_graph, serialized, options)  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.NotFoundError: Op type not registered
'NativeLstm2' in binary running on ip-10-1-21-241. Make sure the Op and Kernel
are registered in the binary running in this process. Note that if you are loading a
saved graph which used ops from tf.contrib, accessing (e.g.) `tf.contrib.resampler`
should be done before importing the graph, as contrib ops are lazily registered when
the module is first accessed.

Is there any way to get around this error? Is it caused by the custom layers used in RETURNN? Is there any way to make a RETURNN model servable with TensorFlow Serving? Thanks.

You should remove the graph=tf.Graph(), otherwise your import_meta_graph will import the meta graph into the wrong graph. See some of the official TF examples of how to use import_meta_graph; a minimal sketch of the suggested pattern is below.
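For illustration, a minimal sketch of that pattern, reusing the checkpoint paths and tensor names from the question. The commented load_op_library path is a hypothetical placeholder, not something given in the question:

import tensorflow as tf

# The NotFoundError itself says the custom op 'NativeLstm2' is not registered
# in this process. If the graph uses RETURNN's compiled native ops, the op
# library would have to be loaded *before* importing the graph. Hypothetical
# path -- the actual .so comes from RETURNN's native-op compilation:
# tf.load_op_library("/path/to/NativeLstm2.so")

with tf.Session() as sess:  # no graph=tf.Graph(): import into the default graph
    new_saver = tf.train.import_meta_graph("./serv_test/model.238.meta")
    new_saver.restore(sess, tf.train.latest_checkpoint("./serv_test"))
    graph = tf.get_default_graph()
    input_audio = graph.get_tensor_by_name('inference/default/wav:0')
    output_hyps = graph.get_tensor_by_name('inference/default/Reshape_7:0')

With the session's graph and the default graph being one and the same, get_tensor_by_name looks up tensors in the graph that import_meta_graph actually populated.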
