簡體   English   中英

保存重新訓練的張量流模型的問題

[英]Issue saving a retrained tensorflow model

我正在嘗試加載模型(以前已保存),並在重新訓練后保存它。 加載效果很好,但是保存時遇到了如下問題:

sess=tf.Session()
sess.run(init)
loader = tf.train.import_meta_graph(self.model_path+'.meta')
loader.restore(sess,self.model_path)#tf.train.latest_checkpoint('./'))            
print('Model restored')
#retrain
saver=tf.train.Saver()
saver.save(sess, self.model_path)

我不會在第一次保存時遇到任何類似的問題,如下所示:

saver=tf.train.Saver()
sess=tf.Session()
sess.run(init)
#train
saver.save(sess, self.model_path)

我遇到的錯誤是:

File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1139, in __init__
    self.build()
  File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1170, in build
    restore_sequentially=self._restore_sequentially)
  File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 673, in build
    saveables = self._ValidateAndSliceInputs(names_to_saveables)
  File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 557, in _ValidateAndSliceInputs
    names_to_saveables = BaseSaverBuilder.OpListToDict(names_to_saveables)
  File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 535, in OpListToDict
    name)
ValueError: At least two variables have the same name: Variable_15/Adam

您看到此消息是因為作用域中有兩個名稱相同的變量。 tf.train.import_meta_graph從文件中讀取一個圖,並將所有操作和張量添加到當前現有圖。 令我驚訝的是, import_meta_graph甚至都沒有觸發過這樣的異常。

請參見完整示例以重現此行為:

import tensorflow as tf

# tiny graph
x = tf.placeholder(tf.float32, shape=[1, 2], name='input')
output = tf.identity(tf.layers.dense(x, 1), name='output')
cost = tf.reduce_sum(x * output)
# create first time u'beta1_power:0', u'beta2_power:0'
train_op = tf.train.AdamOptimizer().minimize(cost)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(tf.global_variables())
    saver.save(sess, './adam/my_model')

    print([v.name for v in tf.global_variables()])

    # create second time u'beta1_power:0', u'beta2_power:0'
    meta_graph = tf.train.import_meta_graph('./adam/my_model.meta')
    meta_graph.restore(sess, './adam/my_model')

    print([v.name for v in tf.global_variables()])

    saver = tf.train.Saver(tf.global_variables())
    # exception as there are now two times: u'beta1_power:0', u'beta2_power:0'
    saver.save(sess, './adam/my_model2')

一個解決方案是

  • 清除利用圖tf.reset_default_graph()之前tf.trainimport_meta_graph
  • tf.train.import_meta_graph使用新會話
  • 只需使用tf.train.Saver().restore(sess, '/tmp/model/my_model')加載權重

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM