Loading a saved TensorFlow model from its .meta file

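I am trying to load a TensorFlow meta graph from a saved checkpoint with TensorFlow 1.15, in order to convert it into a SavedModel for TensorFlow Serving. It is a speech recognition model with local attention and a unidirectional LSTM, implemented with the Returnn toolkit using its TensorFlow backend. I am using the following code: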

import sys

import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

if len(sys.argv) != 2:
    print("Usage: " + sys.argv[0] + " save_dir")
    sys.exit(1)
export_dir = sys.argv[1]
builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir)
sigs = {}
with tf.Session(graph=tf.Graph()) as sess:
    # Rebuild the graph from the .meta file, then restore the latest checkpoint's weights.
    new_saver = tf.train.import_meta_graph("./serv_test/model.238.meta")
    new_saver.restore(sess, tf.train.latest_checkpoint("./serv_test"))
    graph = tf.get_default_graph()
    # Input waveform and output hypotheses of the inference network.
    input_audio = graph.get_tensor_by_name('inference/default/wav:0')
    output_hyps = graph.get_tensor_by_name('inference/default/Reshape_7:0')
    sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
        tf.saved_model.signature_def_utils.predict_signature_def(
            {"in": input_audio}, {"out": output_hyps})
    builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING], signature_def_map=sigs)
builder.save()

However, I get the following error on the import_meta_graph line:

Traceback (most recent call last):
  File "xport.py", line 16, in <module>
    new_saver=tf.train.import_meta_graph("./serv_test/model.238.meta")
  File "/home/ubuntu/tf1.15/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1453, in import_meta_graph
    **kwargs)[0]
  File "/home/ubuntu/tf1.15/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1477, in _import_meta_graph_with_return_elements
    **kwargs))
  File "/home/ubuntu/tf1.15/lib/python3.6/site-packages/tensorflow_core/python/framework/meta_graph.py", line 809, in import_scoped_meta_graph_with_return_elements
    return_elements=return_elements)
  File "/home/ubuntu/tf1.15/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/ubuntu/tf1.15/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 405, in import_graph_def
    producer_op_list=producer_op_list)
  File "/home/ubuntu/tf1.15/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 501, in _import_graph_def_internal
    graph._c_graph, serialized, options)  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.NotFoundError: Op type not registered
 'NativeLstm2' in binary running on ip-10-1-21-241. Make sure the Op and Kernel
 are registered in the binary running in this process. Note that if you are loading a
 saved graph which used ops from tf.contrib, accessing (e.g.) `tf.contrib.resampler`
 should be done before importing the graph, as contrib ops are lazily registered when
 the module is first accessed.

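Is there any way to fix this error? Is it caused by the custom-built layers that Returnn uses? Is there a way to make Returnn models servable with TensorFlow? Thanks.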
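`NativeLstm2` is not a built-in TensorFlow op: Returnn builds it from its own C++/CUDA code, so a plain TensorFlow process has no kernel registered under that name. A common remedy is to load the compiled op library with `tf.load_op_library` before calling `import_meta_graph`. The sketch below assumes you can locate the shared library (`.so`) that Returnn built for the op; the helper name `import_with_custom_ops` and any library path passed to it are hypothetical. It uses `tf.compat.v1`, so it should run on TF 1.15 or 2.x.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def import_with_custom_ops(meta_path, checkpoint_dir, op_libraries=()):
    """Register custom ops, then import a .meta graph and restore its weights.

    op_libraries: paths to compiled shared libraries (.so) that register custom
    ops such as Returnn's NativeLstm2 (hypothetical; depends on your setup).
    Returns an open session bound to the imported graph; the caller closes it.
    """
    # Registration must happen before the import; otherwise import_meta_graph
    # fails with "Op type not registered".
    for library_path in op_libraries:
        tf.load_op_library(library_path)

    graph = tf.Graph()
    with graph.as_default():
        # The session picks up the current default graph, i.e. `graph`.
        sess = tf1.Session()
        saver = tf1.train.import_meta_graph(meta_path)
        saver.restore(sess, tf1.train.latest_checkpoint(checkpoint_dir))
    return sess


# Usage with a hypothetical library path for the Returnn op:
# sess = import_with_custom_ops(
#     "./serv_test/model.238.meta", "./serv_test",
#     op_libraries=["/path/to/NativeLstm2.so"])
```

With no libraries passed, the helper behaves like a plain import and restore. The same requirement applies to whatever process finally loads the graph, so a TensorFlow Serving binary would also need the op registered.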

You should remove graph=tf.Graph(); otherwise import_meta_graph imports the graph into the wrong graph. Just look at how the official TF examples use import_meta_graph.
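A minimal sketch of the pattern this answer points to, written against `tf.compat.v1` so it should run on TF 1.15 or 2.x: the session is created without a `graph=` argument, so it runs whatever graph is current, which is the same graph that `import_meta_graph` imports into and that the SavedModel is exported from. The function name `export_meta_graph_as_saved_model` is illustrative; for the model in the question the tensor names would be `inference/default/wav:0` and `inference/default/Reshape_7:0`. This alone does not register `NativeLstm2`; that op still has to be available in the process.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def export_meta_graph_as_saved_model(meta_path, checkpoint_dir, export_dir,
                                     input_name, output_name):
    """Import a .meta graph, restore its checkpoint and export a SavedModel."""
    # Make a fresh graph current; under TF 2.x eager mode, import_meta_graph
    # only works inside an explicit graph context like this one.
    with tf.Graph().as_default():
        # No graph= argument: the session runs the current default graph,
        # which is the graph import_meta_graph imports into.
        with tf1.Session() as sess:
            saver = tf1.train.import_meta_graph(meta_path)
            saver.restore(sess, tf1.train.latest_checkpoint(checkpoint_dir))
            graph = tf1.get_default_graph()
            assert graph is sess.graph  # import and session share one graph
            signature = tf1.saved_model.signature_def_utils.predict_signature_def(
                inputs={"in": graph.get_tensor_by_name(input_name)},
                outputs={"out": graph.get_tensor_by_name(output_name)})
            key = tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
            builder = tf1.saved_model.builder.SavedModelBuilder(export_dir)
            builder.add_meta_graph_and_variables(
                sess,
                [tf1.saved_model.tag_constants.SERVING],
                signature_def_map={key: signature})
            builder.save()
```

`SavedModelBuilder` expects `export_dir` not to exist yet, so pick a fresh directory for each export.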


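Notice: the technical posts on this site are licensed under CC BY-SA 4.0. If you republish them, please credit this site's URL or the original post's address. For any questions, contact: yoyou2525@163.com.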

 
Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM