[英]How to restore multiple neural network models in TensorFlow?
我正在設計一個具有3個簡單前饋NN的集成神經網絡。 現在,我面臨着恢復這三個神經網絡以進行測試的問題。 到目前為止,通過saver
函數創建並保存了3個NN模型。
saver = tf.train.Saver()
saver.save(sess, save_path=get_save_path(i), global_step=1000)
我已成功將它們保存到“ .checkpoint”,“。meta”,“。index”和“ .data”文件中,如下所示。
我試圖通過使用以下編碼來還原它們:
saver = tf.train.import_meta_graph(get_save_path(i) + '-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint(save_dir))
但是它只恢復了第三個NN,即network2
進行測試。 這影響了我的結果,因為該算法僅采用1個模型( network2
),並假設所有三個network2
模型的集合函數都相同。
僅供參考:
我理想的合奏功能:
ensemble = (network0 + network1 + network2) / 3
實際結果:
ensemble = (network2 + network2 + network2) / 3
如何使TF一起還原所有3個NN模型?
我想你把事情弄混了。 但是,讓我先回答一個問題:
您將需要在不同的范圍內多次創建模型。 然后應該可以對這些變量求平均值。
假設您通過以下方式創建了3個網絡
import tensorflow as tf
# save 3 version
for i in range(3):
tf.reset_default_graph()
a = tf.get_variable('test', [1])
assign_op = a.assign([i])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(assign_op)
print a.name, sess.run(a)
saver = tf.train.Saver(tf.global_variables())
saver.save(sess, './model/version_%i' % i)
在這里,每個網絡都具有相同的圖結構,並且僅包含一個參數/權重名稱“ test”。
然后,您將需要多次創建相同的圖,但是要使用不同的 variable_scopes,例如
# load all versions in different scopes
tf.reset_default_graph()
a_collection = []
for i in range(3):
# use different var-scopes
with tf.variable_scope('scope_%0i' % i):
# create same network
a = tf.get_variable('test', [1])
a_collection.append(a)
現在,每個還原器都需要知道應該使用哪個作用域或變量名映射。 這可以通過
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print zip(sess.run(a_collection), [n.name for n in a_collection])
for i in range(3):
loader = tf.train.Saver({"test": a_collection[i]})
loader = loader.restore(sess, './model/version_%i' % i)
print sess.run(a_collection)
哪個會給你
[array([0.], dtype=float32), array([1.], dtype=float32), array([2.], dtype=float32)]
如預期的那樣。 現在,您可以對模型進行任何操作。
但這不是整體預測的工作原理! 在集成模型中,通常只對預測取平均。 因此,您可能會使用不同的模型多次運行腳本,然后平均預測值。
如果您真的想平均模型的權重,請考慮使用numpy將權重作為python-dict轉儲。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.