簡體   English   中英

Tensorflow:保存和恢復變量問題

[英]Tensorflow: save and restore variable issue

如何在tensorflow中保存和恢復變量?

我遇到了問題。 我的代碼:

import tensorflow as tf

v1 = tf.Variable(tf.zeros([2, 2], dtype=tf.float32, name='v1'))
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(v1)
    save_path = saver.save(sess, 'model.ckpt')
    print "model saved in file:", save_path
    v1 = v1 + 1
    print sess.run(v1)
    saver = tf.train.import_meta_graph('model.ckpt.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    print sess.run(v1)

結果:

[[ 0.  0.]
 [ 0.  0.]]

[[ 1.  1.]
 [ 1.  1.]]

[[ 1.  1.]
 [ 1.  1.]]

我希望得到:

[[ 0.  0.]
 [ 0.  0.]]

[[ 1.  1.]
 [ 1.  1.]]

[[ 0.  0.]
 [ 0.  0.]]

我犯了什么錯誤?

請幫我理解。

您的代碼中有兩個主要問題:

  1. v1 = v1 + 1創建一個新的TensorFlow Tensor並將其綁定到Python變量v1 ,但不會更改您使用名稱"v1"創建的TensorFlow Variable的值。 因此,當您稍后調用sess.run(v1) ,您正在評估將原始變量加1的新張量,而不是從張量中讀取值。

    相反,要將變量添加到變量,您應該使用以下內容:

     increment_op = v1.assign_add(tf.ones([2, 2])) sess.run(increment_op) 
  2. tf.train.import_meta_graph()調用重新創建原始圖形,並在此過程中向圖形中添加新節點,包括新的tf.train.Saver 當您尚未構建圖形(或者沒有程序可用於執行該圖形)時,它非常有用。 由於您已經構建了圖形,因此只需要使用saver.restore(sess, 'model.ckpt')

以下程序應該產生您預期的行為:

import tensorflow as tf

v1 = tf.Variable(tf.zeros([2, 2], dtype=tf.float32, name='v1'))
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(v1)
    save_path = saver.save(sess, './model.ckpt')
    print "model saved in file:", save_path

    # Create an op to increment v1, run it, and print the result.   
    increment_op = v1.assign_add(tf.ones([2, 2]))
    sess.run(increment_op)
    print sess.run(v1)

    # Restore from the checkpoint saved above.
    saver.restore(sess, './model.ckpt')
    print sess.run(v1)

雖然,所選答案告訴我們應該做什么但不解釋為什么你得到了意想不到的答案。 我正在為稍后來這里的人解釋。

在Tensorflow中,如果您已經有一個圖表,並且在保存之后再次導入相同的圖形,則不會替換圖形操作,而是Tensorflow旨在通過添加后綴如_1,_2等來創建新變量。例如,你的情況,在你做之前:saver = tf.train.import_meta_graph('model.ckpt.meta')saver.restore(sess,tf.train.latest_checkpoint('。/'))你的圖有一個叫做v1的變量。 導入相同的圖形后,您的變量v1將不會被替換,而是將新的變量v1_1添加到圖形中。 因此,圖表的大小將加倍。 由於v1未通過加載圖形而改變,因此您仍然可以獲得v1的舊值(全1)。

如果要重置圖形,則必須在再次導入圖形之前使用tf.reset_default_graph(),如文檔中所述。 如果你在此之后導入並打印v1,你將得到一個全0 v1。

文件可能會對此有所了解。 我用一兩個修改運行你的文件:

import tensorflow as tf

v1 = tf.Variable(tf.zeros([2, 2], dtype=tf.float32, name='v1'))
saver = tf.train.Saver()

tf.add_to_collection('v1', v1)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    print sess.run(v1)
    save_path = saver.save(sess, 'model.ckpt')
    print "model saved in file:", save_path
    v1 = v1 + 1
    print sess.run(v1)
    saver = tf.train.import_meta_graph('model.ckpt.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    print sess.run(v1)

注意tf.add_to_collection調用。 之后,我跑了這個:

import tensorflow as tf

sess = tf.Session()
saver = tf.train.import_meta_graph('model.ckpt.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
print sess.run(tf.get_collection('v1')[0])

隨着輸出:

[[ 0.  0.]
 [ 0.  0.]]

看起來恢復的東西實際上不會修改你當前的計算圖,你需要使用集合來獲得你想要的東西。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM