
Restore a saved neural network in Tensorflow

Before you mark my question as a duplicate, I want you to know that I have gone through a lot of questions, but none of the solutions there could clear my doubts or solve my problem. I have a trained neural network that I want to save, so that I can later test this model against a test dataset.

I tried saving and restoring it, but I am not getting the expected results. Restoring doesn't seem to work; maybe I am using it wrongly, and the model is just using the values given by the global variables initializer.

This is the code I am using for saving the model.

sess.run(tf.initializers.global_variables())
# num_epochs = 7
for epoch in range(num_epochs):
    start_time = time.time()
    train_accuracy = 0
    train_loss = 0
    val_loss = 0
    val_accuracy = 0

    # Training: one pass over the training set in mini-batches.
    for bid in range(int(train_data_size / batch_size)):
        X_train_batch = X_train[bid * batch_size:(bid + 1) * batch_size]
        y_train_batch = y_train[bid * batch_size:(bid + 1) * batch_size]
        sess.run(optimizer, feed_dict={x: X_train_batch, y: y_train_batch, prob: 0.50})

        train_accuracy += sess.run(model_accuracy, feed_dict={x: X_train_batch, y: y_train_batch, prob: 0.50})
        train_loss += sess.run(loss_value, feed_dict={x: X_train_batch, y: y_train_batch, prob: 0.50})

    # Validation: accumulate accuracy/loss over the validation batches.
    for bid in range(int(val_data_size / batch_size)):
        X_val_batch = X_val[bid * batch_size:(bid + 1) * batch_size]
        y_val_batch = y_val[bid * batch_size:(bid + 1) * batch_size]
        val_accuracy += sess.run(model_accuracy, feed_dict={x: X_val_batch, y: y_val_batch, prob: 0.75})
        val_loss += sess.run(loss_value, feed_dict={x: X_val_batch, y: y_val_batch, prob: 0.75})

    # Average the accumulated metrics over the number of batches.
    train_accuracy = train_accuracy / int(train_data_size / batch_size)
    val_accuracy = val_accuracy / int(val_data_size / batch_size)
    train_loss = train_loss / int(train_data_size / batch_size)
    val_loss = val_loss / int(val_data_size / batch_size)

    end_time = time.time()

    # Save a checkpoint at the end of every epoch; global_step=epoch
    # appends the epoch number to the checkpoint file names.
    saver.save(sess, './blood_model_x_v2', global_step=epoch)
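For context, `saver` must be a `tf.train.Saver` created before the training loop; the question does not show that line, so the following is an assumption:

    # Assumed setup, not shown in the original snippet.
    # tf.train.Saver keeps only the 5 most recent checkpoints by default
    # (max_to_keep=5), which is consistent with only v2-2 ... v2-6
    # remaining on disk after 7 epochs saved with global_step = 0 .. 6.
    saver = tf.train.Saver()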

After saving the model, the following files are written to my working directory:

blood_model_x_v2-2.data-00000-of-00001
blood_model_x_v2-2.index
blood_model_x_v2-2.meta

There are similar files for v2-3, and so on up to v2-6, plus a "checkpoint" file. I then tried to restore it with this snippet (after initialization), but the results are not what I expected. What am I doing wrong?

saver = tf.train.import_meta_graph('blood_model_x_v2-5.meta')
saver.restore(test_session, tf.train.latest_checkpoint('./'))
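A common pitfall here: if `tf.global_variables_initializer()` is run in the test session *after* `saver.restore`, the restored weights get overwritten with fresh initial values. A minimal sketch of restoring without re-initializing (the tensor names and test arrays below are assumptions; they depend on the `name=` arguments and data used when the graph was built):

import tensorflow as tf

test_session = tf.Session()
saver = tf.train.import_meta_graph('blood_model_x_v2-5.meta')
saver.restore(test_session, tf.train.latest_checkpoint('./'))
# No tf.global_variables_initializer() here: running it after
# restore would overwrite the restored weights.

graph = tf.get_default_graph()
# Hypothetical tensor names; replace them with the names actually
# assigned when the placeholders/ops were defined.
x = graph.get_tensor_by_name('x:0')
y = graph.get_tensor_by_name('y:0')
prob = graph.get_tensor_by_name('prob:0')
model_accuracy = graph.get_tensor_by_name('model_accuracy:0')

# X_test / y_test: the held-out test set (assumed to exist);
# prob: 1.0 assumes `prob` is a dropout keep-probability.
test_accuracy = test_session.run(
    model_accuracy,
    feed_dict={x: X_test, y: y_test, prob: 1.0})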

According to the tensorflow docs:

restore: Restores previously saved variables.

This method runs the ops added by the constructor for restoring variables. It requires a session in which the graph was launched. The variables to restore do not have to have been initialized, as restoring is itself a way to initialize variables.

Let's take an example:

We save a model along these lines:

import tensorflow as tf

# Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

# Define a test operation that we will restore
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create a saver object which will save all the variables
saver = tf.train.Saver()

# Run the operation by feeding input
print(sess.run(w4, feed_dict))
# Prints 24, which is (w1 + w2) * b1 = (4 + 8) * 2.0

# Now, save the graph
saver.save(sess, './ckpnt/my_test_model', global_step=1000)

Then we load the trained model with:

import tensorflow as tf

sess = tf.Session()
# First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('./ckpnt/my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./ckpnt'))

# Now, access the placeholders and create a feed_dict
# with new data to feed in

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print(sess.run(op_to_restore, feed_dict))
# This will print 60, which is calculated from the new values
# of w1 and w2 and the saved value of b1: (13 + 17) * 2.0 = 60.

As you can see, we did not initialize the session in the restore part. There is a better way to save and restore a model with Checkpoint, which lets you verify that the model was restored correctly.
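For illustration, a minimal sketch of the object-based tf.train.Checkpoint API in TF 1.x graph mode (available from TF 1.11; the path ./ckpnt_obj and variable names are arbitrary choices for this sketch):

import tensorflow as tf

# Same toy graph as above.
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1 = tf.Variable(2.0, name="bias")
w4 = tf.multiply(tf.add(w1, w2), b1, name="op_to_restore")

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Object-based checkpointing: explicitly list what to track.
ckpt = tf.train.Checkpoint(bias=b1)
save_path = ckpt.save('./ckpnt_obj/my_test_model', session=sess)

# assert_consumed() raises if anything tracked was not found in
# the checkpoint, so a bad restore fails loudly instead of
# silently keeping initializer values.
status = ckpt.restore(save_path)
status.assert_consumed().run_restore_ops(session=sess)
print(sess.run(w4, {w1: 13.0, w2: 17.0}))  # 60.0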
