[英]How to not re-initialize the pretrained loaded model in Tensorflow?
我使用以下代码加载了一个预训练模型( Model 1
):
def load_seq2seq_model(sess):
with open(os.path.join(seq2seq_config_dir_path, 'config.pkl'), 'rb') as f:
saved_args = pickle.load(f)
# Initialize the model with saved args
model = Model1(saved_args)
#Inititalize Tensorflow saver
saver = tf.train.Saver()
# Checkpoint
ckpt = tf.train.get_checkpoint_state(seq2seq_config_dir_path)
print('Loading model: ', ckpt.model_checkpoint_path)
# Restore the model at the checkpoint
saver.restore(sess, ckpt.model_checkpoint_path)
return model
现在,我想从头开始训练另一个模型( Model 2
),它将采用Model 1
的输出。 但是为此我需要定义一个会话并加载预训练模型并初始化模型tf.initialize_all_variables()
。 因此,预训练的模型也将被初始化。
谁能告诉我如何火车Model 2
从预先训练的模型得到的输出后, Model 1
正常吗?
我正在尝试的内容如下-
with tf.Session() as sess:
# Initialize all the variables of the graph
seq2seq_model = load_seq2seq_model(sess)
sess.run(tf.initialize_all_variables())
.... Rest of the training code goes here....
使用保护程序恢复的所有变量都不需要初始化。 因此,您可以使用tf.variables_initializer(var_list)
来仅初始化第二个网络的权重,而不是使用tf.initialize_all_variables()
。
要获取第二个网络的所有权重列表,您可以在可变范围内创建Model 2
网络:
with tf.variable_scope("model2"):
model2 = Model2(...)
然后使用
model_2_variables_list = tf.get_collection(
tf.GraphKeys.GLOBAL_VARIABLES,
scope="model2"
)
获取Model 2
网络的变量列表。 最后,您可以为第二个网络创建初始化程序:
init2 = tf.variables_initializer(model_2_variables_list)
with tf.Session() as sess:
# Initialize all the variables of the graph
seq2seq_model = load_seq2seq_model(sess)
sess.run(init2)
.... Rest of the training code goes here....
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.