
How to avoid re-initializing a loaded pretrained model in TensorFlow?

I have loaded a pretrained model (Model 1) using the following code:

import os
import pickle
import tensorflow as tf

def load_seq2seq_model(sess):
    # Load the saved constructor arguments
    with open(os.path.join(seq2seq_config_dir_path, 'config.pkl'), 'rb') as f:
        saved_args = pickle.load(f)

    # Initialize the model with the saved args
    model = Model1(saved_args)

    # Initialize the TensorFlow saver
    saver = tf.train.Saver()

    # Get the checkpoint state
    ckpt = tf.train.get_checkpoint_state(seq2seq_config_dir_path)
    print('Loading model: ', ckpt.model_checkpoint_path)

    # Restore the model from the checkpoint
    saver.restore(sess, ckpt.model_checkpoint_path)
    return model

Now I want to train another model (Model 2) from scratch, which will take the output of Model 1 as its input. To do that, I need to create a session, load the pretrained model, and initialize the new model's variables with tf.initialize_all_variables(). The problem is that this call re-initializes the pretrained model's variables as well.

Can anyone tell me how to properly train Model 2 on the output of the pretrained Model 1?

What I am trying is shown below:

with tf.Session() as sess:
    seq2seq_model = load_seq2seq_model(sess)
    # Initialize all the variables of the graph
    sess.run(tf.initialize_all_variables())
    .... Rest of the training code goes here....

Variables that are restored through a saver do not need to be initialized. So instead of tf.initialize_all_variables(), you can use tf.variables_initializer(var_list) to initialize only the weights of the second network.

To get a list of all the weights of the second network you can create the Model 2 network in a variable scope:

with tf.variable_scope("model2"):
    model2 = Model2(...)

Then use

model_2_variables_list = tf.get_collection(
    tf.GraphKeys.GLOBAL_VARIABLES, 
    scope="model2"
)

to get the variable list of the Model 2 network. Finally, you can create the initializer for the second network:

init2 = tf.variables_initializer(model_2_variables_list)
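
As a quick sanity check, you can print the names in this list; each one should start with the scope prefix, which confirms that none of Model 1's variables were picked up:

for v in model_2_variables_list:
    print(v.name)  # every name should start with "model2/"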

with tf.Session() as sess:
    # Restore Model 1's weights from the checkpoint
    seq2seq_model = load_seq2seq_model(sess)
    # Initialize only Model 2's variables
    sess.run(init2)
    .... Rest of the training code goes here....
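
Putting the pieces together, here is a minimal sketch of the full flow. The Model2 constructor, the seq2seq_model.output tensor, and the loss attribute are hypothetical placeholders for whatever your classes actually expose; training is restricted to Model 2's variables via var_list so that Model 1's restored weights stay frozen:

with tf.Session() as sess:
    # The Saver inside load_seq2seq_model is created before Model 2 exists,
    # so it only covers (and restores) Model 1's variables
    seq2seq_model = load_seq2seq_model(sess)

    # Build Model 2 under its own variable scope, fed by Model 1's output
    with tf.variable_scope("model2"):
        model2 = Model2(seq2seq_model.output)  # hypothetical constructor / output tensor
        loss = model2.loss                     # hypothetical loss tensor

    # Collect Model 2's variables and train only those, leaving Model 1 untouched
    model_2_variables_list = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope="model2")
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
        loss, var_list=model_2_variables_list)

    # Initialize only Model 2's variables; Model 1 keeps its restored weights
    sess.run(tf.variables_initializer(model_2_variables_list))

    # ... rest of the training loop: sess.run(train_op, feed_dict=...) ...

GradientDescentOptimizer is used here because it creates no extra variables of its own; optimizers such as Adam create internal slot variables that would also need to be initialized before training.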
