[英]How to not re-initialize the pretrained loaded model in Tensorflow?
我使用以下代碼加載了一個預訓練模型( Model 1
):
def load_seq2seq_model(sess):
with open(os.path.join(seq2seq_config_dir_path, 'config.pkl'), 'rb') as f:
saved_args = pickle.load(f)
# Initialize the model with saved args
model = Model1(saved_args)
#Inititalize Tensorflow saver
saver = tf.train.Saver()
# Checkpoint
ckpt = tf.train.get_checkpoint_state(seq2seq_config_dir_path)
print('Loading model: ', ckpt.model_checkpoint_path)
# Restore the model at the checkpoint
saver.restore(sess, ckpt.model_checkpoint_path)
return model
現在,我想從頭開始訓練另一個模型( Model 2
),它將采用Model 1
的輸出。 但是為此我需要定義一個會話並加載預訓練模型並初始化模型tf.initialize_all_variables()
。 因此,預訓練的模型也將被初始化。
誰能告訴我如何火車Model 2
從預先訓練的模型得到的輸出后, Model 1
正常嗎?
我正在嘗試的內容如下-
with tf.Session() as sess:
# Initialize all the variables of the graph
seq2seq_model = load_seq2seq_model(sess)
sess.run(tf.initialize_all_variables())
.... Rest of the training code goes here....
使用保護程序恢復的所有變量都不需要初始化。 因此,您可以使用tf.variables_initializer(var_list)
來僅初始化第二個網絡的權重,而不是使用tf.initialize_all_variables()
。
要獲取第二個網絡的所有權重列表,您可以在可變范圍內創建Model 2
網絡:
with tf.variable_scope("model2"):
model2 = Model2(...)
然后使用
model_2_variables_list = tf.get_collection(
tf.GraphKeys.GLOBAL_VARIABLES,
scope="model2"
)
獲取Model 2
網絡的變量列表。 最后,您可以為第二個網絡創建初始化程序:
init2 = tf.variables_initializer(model_2_variables_list)
with tf.Session() as sess:
# Initialize all the variables of the graph
seq2seq_model = load_seq2seq_model(sess)
sess.run(init2)
.... Rest of the training code goes here....
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.