簡體   English   中英

如何從python中的.pb文件恢復Tensorflow模型?

[英]How to restore Tensorflow model from .pb file in python?

我有一個 tensorflow .pb 文件,我想將其加載到 python DNN 中,恢復圖形並獲得預測。 我這樣做是為了測試創建的 .pb 文件是否可以做出類似於普通 Saver.save() 模型的預測。

我的基本問題是,當我使用上述 .pb 文件在 Android 上進行預測時,得到的預測值非常不同

我的 .pb 文件創建代碼:

frozen_graph = tf.graph_util.convert_variables_to_constants(
        session,
        session.graph_def,
        ['outputLayer/Softmax']
    )
with open('frozen_model.pb', 'wb') as f:
  f.write(frozen_graph.SerializeToString())

所以我有兩個主要問題:

  1. 如何將上述 .pb 文件加載到 python Tensorflow 模型?
  2. 為什么我在 python 和 android 中得到完全不同的預測值?

以下代碼將讀取模型並打印出圖中節點的名稱。

import tensorflow as tf
from tensorflow.python.platform import gfile
GRAPH_PB_PATH = './frozen_model.pb'
with tf.Session() as sess:
   print("load graph")
   with gfile.FastGFile(GRAPH_PB_PATH,'rb') as f:
       graph_def = tf.GraphDef()
   graph_def.ParseFromString(f.read())
   sess.graph.as_default()
   tf.import_graph_def(graph_def, name='')
   graph_nodes=[n for n in graph_def.node]
   names = []
   for t in graph_nodes:
      names.append(t.name)
   print(names)

您正確地凍結了圖形,這就是為什么您得到不同結果的原因基本上權重沒有存儲在您的模型中。 您可以使用freeze_graph.py鏈接)獲取正確存儲的圖形。

這是 tensorflow 2 的更新代碼。

import tensorflow as tf

GRAPH_PB_PATH = './frozen_model.pb'
with tf.compat.v1.Session() as sess:
   print("load graph")
   with tf.io.gfile.GFile(GRAPH_PB_PATH,'rb') as f:
       graph_def = tf.compat.v1.GraphDef()
   graph_def.ParseFromString(f.read())
   sess.graph.as_default()
   tf.import_graph_def(graph_def, name='')
   graph_nodes=[n for n in graph_def.node]
   names = []
   for t in graph_nodes:
      names.append(t.name)
   print(names)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM