簡體   English   中英

如何不在 Tensorflow 中重新初始化預訓練的加載模型?

[英]How to not re-initialize the pretrained loaded model in Tensorflow?

我使用以下代碼加載了一個預訓練模型( Model 1 ):

def load_seq2seq_model(sess):


    with open(os.path.join(seq2seq_config_dir_path, 'config.pkl'), 'rb') as f:
        saved_args = pickle.load(f)

    # Initialize the model with saved args
    model = Model1(saved_args)

    #Inititalize Tensorflow saver
    saver = tf.train.Saver()

    # Checkpoint 
    ckpt = tf.train.get_checkpoint_state(seq2seq_config_dir_path)
    print('Loading model: ', ckpt.model_checkpoint_path)

    # Restore the model at the checkpoint
    saver.restore(sess, ckpt.model_checkpoint_path)
    return model

現在,我想從頭開始訓練另一個模型( Model 2 ),它將采用Model 1的輸出。 但是為此我需要定義一個會話並加載預訓練模型並初始化模型tf.initialize_all_variables() 因此,預訓練的模型也將被初始化。

誰能告訴我如何火車Model 2從預先訓練的模型得到的輸出后, Model 1正常嗎?

我正在嘗試的內容如下-

with tf.Session() as sess:
    # Initialize all the variables of the graph
    seq2seq_model = load_seq2seq_model(sess)
    sess.run(tf.initialize_all_variables())
    .... Rest of the training code goes here....

使用保護程序恢復的所有變量都不需要初始化。 因此,您可以使用tf.variables_initializer(var_list)來僅初始化第二個網絡的權重,而不是使用tf.initialize_all_variables()

要獲取第二個網絡的所有權重列表,您可以在可變范圍內創建Model 2網絡:

with tf.variable_scope("model2"):
    model2 = Model2(...)

然后使用

model_2_variables_list = tf.get_collection(
    tf.GraphKeys.GLOBAL_VARIABLES, 
    scope="model2"
)

獲取Model 2網絡的變量列表。 最后,您可以為第二個網絡創建初始化程序:

init2 = tf.variables_initializer(model_2_variables_list)

with tf.Session() as sess:
    # Initialize all the variables of the graph
    seq2seq_model = load_seq2seq_model(sess)
    sess.run(init2)
    .... Rest of the training code goes here....

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM