簡體   English   中英

如何在TensorFlow中還原多個神經網絡模型?

[英]How to restore multiple neural network models in TensorFlow?

我正在設計一個具有3個簡單前饋NN的集成神經網絡。 現在,我面臨着恢復這三個神經網絡以進行測試的問題。 到目前為止,通過saver函數創建並保存了3個NN模型。

saver = tf.train.Saver()    
saver.save(sess, save_path=get_save_path(i), global_step=1000)

我已成功將它們保存到“ .checkpoint”,“。meta”,“。index”和“ .data”文件中,如下所示。

在此處輸入圖片說明

我試圖通過使用以下編碼來還原它們:

 saver = tf.train.import_meta_graph(get_save_path(i) + '-1000.meta')
 saver.restore(sess,tf.train.latest_checkpoint(save_dir))

但是它只恢復了第三個NN,即network2進行測試。 這影響了我的結果,因為該算法僅采用1個模型( network2 ),並假設所有三個network2模型的集合函數都相同。

僅供參考:

我理想的合奏功能:

ensemble = (network0 + network1 + network2) / 3

實際結果:

ensemble = (network2 + network2 + network2) / 3

如何使TF一起還原所有3個NN模型?

我想你把事情弄混了。 但是,讓我先回答一個問題:

您將需要在不同的范圍內多次創建模型。 然后應該可以對這些變量求平均值。

假設您通過以下方式創建了3個網絡

import tensorflow as tf

# save 3 version
for i in range(3):
    tf.reset_default_graph()

    a = tf.get_variable('test', [1])

    assign_op = a.assign([i])

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(assign_op)
        print a.name, sess.run(a)

        saver = tf.train.Saver(tf.global_variables())
        saver.save(sess, './model/version_%i' % i)

在這里,每個網絡都具有相同的圖結構,並且僅包含一個參數/權重名稱“ test”。

然后,您將需要多次創建相同的圖,但是要使用不同的 variable_scopes,例如

# load all versions in different scopes
tf.reset_default_graph()

a_collection = []

for i in range(3):
    # use different var-scopes
    with tf.variable_scope('scope_%0i' % i):
        # create same network
        a = tf.get_variable('test', [1])
        a_collection.append(a)

現在,每個還原器都需要知道應該使用哪個作用域或變量名映射。 這可以通過

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    print zip(sess.run(a_collection), [n.name for n in a_collection])

    for i in range(3):
        loader = tf.train.Saver({"test": a_collection[i]})
        loader = loader.restore(sess, './model/version_%i' % i)

    print sess.run(a_collection)

哪個會給你

 [array([0.], dtype=float32), array([1.], dtype=float32), array([2.], dtype=float32)]

如預期的那樣。 現在,您可以對模型進行任何操作。

但這不是整體預測的工作原理! 在集成模型中,通常對預測取平均。 因此,您可能會使用不同的模型多次運行腳本,然后平均預測值。

如果您真的想平均模型的權重,請考慮使用numpy將權重作為python-dict轉儲。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM