繁体   English   中英

如何不在 Tensorflow 中重新初始化预训练的加载模型?

[英]How to not re-initialize the pretrained loaded model in Tensorflow?

我使用以下代码加载了一个预训练模型( Model 1 ):

def load_seq2seq_model(sess):


    with open(os.path.join(seq2seq_config_dir_path, 'config.pkl'), 'rb') as f:
        saved_args = pickle.load(f)

    # Initialize the model with saved args
    model = Model1(saved_args)

    #Inititalize Tensorflow saver
    saver = tf.train.Saver()

    # Checkpoint 
    ckpt = tf.train.get_checkpoint_state(seq2seq_config_dir_path)
    print('Loading model: ', ckpt.model_checkpoint_path)

    # Restore the model at the checkpoint
    saver.restore(sess, ckpt.model_checkpoint_path)
    return model

现在,我想从头开始训练另一个模型( Model 2 ),它将采用Model 1的输出。 但是为此我需要定义一个会话并加载预训练模型并初始化模型tf.initialize_all_variables() 因此,预训练的模型也将被初始化。

谁能告诉我如何火车Model 2从预先训练的模型得到的输出后, Model 1正常吗?

我正在尝试的内容如下-

with tf.Session() as sess:
    # Initialize all the variables of the graph
    seq2seq_model = load_seq2seq_model(sess)
    sess.run(tf.initialize_all_variables())
    .... Rest of the training code goes here....

使用保护程序恢复的所有变量都不需要初始化。 因此,您可以使用tf.variables_initializer(var_list)来仅初始化第二个网络的权重,而不是使用tf.initialize_all_variables()

要获取第二个网络的所有权重列表,您可以在可变范围内创建Model 2网络:

with tf.variable_scope("model2"):
    model2 = Model2(...)

然后使用

model_2_variables_list = tf.get_collection(
    tf.GraphKeys.GLOBAL_VARIABLES, 
    scope="model2"
)

获取Model 2网络的变量列表。 最后,您可以为第二个网络创建初始化程序:

init2 = tf.variables_initializer(model_2_variables_list)

with tf.Session() as sess:
    # Initialize all the variables of the graph
    seq2seq_model = load_seq2seq_model(sess)
    sess.run(init2)
    .... Rest of the training code goes here....

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM