简体   繁体   English

如何从python中的.pb文件恢复Tensorflow模型?

[英]How to restore Tensorflow model from .pb file in python?

I have an tensorflow .pb file which I would like to load into python DNN, restore the graph and get the predictions.我有一个 tensorflow .pb 文件,我想将其加载到 python DNN 中,恢复图形并获得预测。 I am doing this to test out whether the .pb file created can make the predictions similar to the normal Saver.save() model.我这样做是为了测试创建的 .pb 文件是否可以做出类似于普通 Saver.save() 模型的预测。

My basic problem is am getting a very different value of predictions when I make them on Android using the above mentioned .pb file我的基本问题是,当我使用上述 .pb 文件在 Android 上进行预测时,得到的预测值非常不同

My .pb file creation code:我的 .pb 文件创建代码:

frozen_graph = tf.graph_util.convert_variables_to_constants(
        session,
        session.graph_def,
        ['outputLayer/Softmax']
    )
with open('frozen_model.pb', 'wb') as f:
  f.write(frozen_graph.SerializeToString())

So I have two major concerns:所以我有两个主要问题:

  1. How can I load the above mentioned .pb file to python Tensorflow model ?如何将上述 .pb 文件加载到 python Tensorflow 模型?
  2. Why am I getting completely different values of prediction in python and android ?为什么我在 python 和 android 中得到完全不同的预测值?

The following code will read the model and print out the names of the nodes in the graph.以下代码将读取模型并打印出图中节点的名称。

import tensorflow as tf
from tensorflow.python.platform import gfile
GRAPH_PB_PATH = './frozen_model.pb'
with tf.Session() as sess:
   print("load graph")
   with gfile.FastGFile(GRAPH_PB_PATH,'rb') as f:
       graph_def = tf.GraphDef()
   graph_def.ParseFromString(f.read())
   sess.graph.as_default()
   tf.import_graph_def(graph_def, name='')
   graph_nodes=[n for n in graph_def.node]
   names = []
   for t in graph_nodes:
      names.append(t.name)
   print(names)

You are freezing the graph properly that is why you are getting different results basically weights are not getting stored in your model.您正确地冻结了图形,这就是为什么您得到不同结果的原因基本上权重没有存储在您的模型中。 You can use the freeze_graph.py ( link ) for getting a correctly stored graph.您可以使用freeze_graph.py链接)获取正确存储的图形。

Here is the updated code for tensorflow 2.这是 tensorflow 2 的更新代码。

import tensorflow as tf

GRAPH_PB_PATH = './frozen_model.pb'
with tf.compat.v1.Session() as sess:
   print("load graph")
   with tf.io.gfile.GFile(GRAPH_PB_PATH,'rb') as f:
       graph_def = tf.compat.v1.GraphDef()
   graph_def.ParseFromString(f.read())
   sess.graph.as_default()
   tf.import_graph_def(graph_def, name='')
   graph_nodes=[n for n in graph_def.node]
   names = []
   for t in graph_nodes:
      names.append(t.name)
   print(names)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM