I have loaded a pretrained model (Model 1) using the following code:
import os
import pickle

import tensorflow as tf

def load_seq2seq_model(sess):
    with open(os.path.join(seq2seq_config_dir_path, 'config.pkl'), 'rb') as f:
        saved_args = pickle.load(f)
    # Build the model from the saved args
    model = Model1(saved_args)
    # Initialize the TensorFlow saver
    saver = tf.train.Saver()
    # Locate the latest checkpoint in the config directory
    ckpt = tf.train.get_checkpoint_state(seq2seq_config_dir_path)
    print('Loading model: ', ckpt.model_checkpoint_path)
    # Restore the model from the checkpoint
    saver.restore(sess, ckpt.model_checkpoint_path)
    return model
Now I want to train another model (Model 2) from scratch that takes the output of Model 1 as its input. For that I need to define a session, load the pretrained model, and initialize the variables with tf.initialize_all_variables(). But that also re-initializes the pretrained model, overwriting the weights I just restored.
Can anyone please tell me how to properly train Model 2 on the output of the pretrained Model 1?
Here is what I am trying:
with tf.Session() as sess:
    seq2seq_model = load_seq2seq_model(sess)
    # Initialize all the variables of the graph
    sess.run(tf.initialize_all_variables())
    # ... rest of the training code goes here ...
Variables that are restored with a saver don't need to be initialized. So instead of tf.initialize_all_variables(), which runs the initializer of every variable in the graph and therefore overwrites the restored weights, you can use tf.variables_initializer(var_list) to initialize only the variables of the second network.
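The difference can be seen in a small sketch. It assumes the TF1-style graph API through tf.compat.v1 (so it also runs on TF 2.x); the variable names model1_w and model2_w are hypothetical, and the tf1.assign call only stands in for saver.restore() so that the example needs no checkpoint on disk.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API, also available in TF 2.x

graph = tf.Graph()
with graph.as_default():
    # Hypothetical stand-in for a variable of the pretrained Model 1
    pretrained_w = tf1.get_variable(
        "model1_w", shape=[2], initializer=tf1.zeros_initializer())
    # Hypothetical stand-in for a fresh variable of Model 2
    new_w = tf1.get_variable(
        "model2_w", shape=[2], initializer=tf1.ones_initializer())

    # Stand-in for saver.restore(): give the "pretrained" variable known values
    fake_restore = tf1.assign(pretrained_w, [3.0, 4.0])
    # Initializes only the listed variables
    init_model2_only = tf1.variables_initializer([new_w])
    # Initializes every global variable, including the "restored" one
    init_everything = tf1.global_variables_initializer()

with tf1.Session(graph=graph) as sess:
    sess.run(fake_restore)
    sess.run(init_model2_only)
    # The restored values survive, Model 2's variable is now initialized
    after_selective_init = sess.run([pretrained_w, new_w])

    sess.run(init_everything)
    # The restored values are reset to the initializer's zeros
    after_global_init = sess.run(pretrained_w)

print(after_selective_init)
print(after_global_init)
```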
To get a list of all the variables of the second network, you can create the Model 2 network in a variable scope:
with tf.variable_scope("model2"):
    model2 = Model2(...)
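Collecting the variables by scope works because tf.get_collection filters on the variable names with re.match, i.e. a prefix match on the name string. A small sketch (tf.compat.v1 again; the scope model2_extra is hypothetical) shows one consequence: scope="model2" also picks up any scope whose name merely starts with "model2", while scope="model2/" does not.

```python
import tensorflow as tf

tf1 = tf.compat.v1

graph = tf.Graph()
with graph.as_default():
    with tf1.variable_scope("model1"):
        tf1.get_variable("w", shape=[3])
    with tf1.variable_scope("model2"):
        tf1.get_variable("w", shape=[3])
        tf1.get_variable("b", shape=[1])
    # Hypothetical unrelated scope whose name starts with "model2"
    with tf1.variable_scope("model2_extra"):
        tf1.get_variable("w", shape=[1])

    # Name filter is re.match(scope, variable.name), a prefix match
    loose = tf1.get_collection(tf1.GraphKeys.GLOBAL_VARIABLES, scope="model2")
    strict = tf1.get_collection(tf1.GraphKeys.GLOBAL_VARIABLES, scope="model2/")

loose_names = [v.name for v in loose]    # also contains model2_extra/w:0
strict_names = [v.name for v in strict]  # only the model2/ variables
print(loose_names)
print(strict_names)
```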
Then use
model_2_variables_list = tf.get_collection(
    tf.GraphKeys.GLOBAL_VARIABLES,
    scope="model2"
)
to get the list of Model 2's variables. Finally, you can create the initializer for the second network:
init2 = tf.variables_initializer(model_2_variables_list)

with tf.Session() as sess:
    # Restore Model 1 from its checkpoint
    seq2seq_model = load_seq2seq_model(sess)
    # Initialize only the variables of Model 2
    sess.run(init2)
    # ... rest of the training code goes here ...
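Putting the pieces together, here is a minimal end-to-end sketch under stated assumptions: build_model1 and the single-variable Model 2 are hypothetical stand-ins for Model1/Model2, a temporary checkpoint plays the role of the one in seq2seq_config_dir_path, and tf.compat.v1 is used so it also runs on TF 2.x. Two details are deliberate. The Saver gets var_list so it only covers Model 1: tf.train.Saver() with no arguments covers every global variable in the graph at the time it is created, so if Model 2 already exists, restore() would look for Model 2's variables in a checkpoint that does not contain them. And Model 1 is built first, so Model 2 can consume its output.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1


def build_model1(x):
    """Hypothetical stand-in for Model 1: a single linear layer."""
    with tf1.variable_scope("model1"):
        w = tf1.get_variable("w", shape=[4, 3])
        return tf.matmul(x, w)


ckpt_dir = tempfile.mkdtemp()

# 1) Pretend Model 1 was pretrained and saved (stands in for the real checkpoint)
with tf.Graph().as_default():
    x = tf1.placeholder(tf.float32, [None, 4])
    build_model1(x)
    pretrain_saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        pretrain_saver.save(sess, os.path.join(ckpt_dir, "model1.ckpt"))
        saved_w = sess.run(tf1.global_variables()[0])

# 2) Fresh graph: restore Model 1, train Model 2 on top of its output
with tf.Graph().as_default():
    x = tf1.placeholder(tf.float32, [None, 4])
    y = tf1.placeholder(tf.float32, [None, 1])
    features = build_model1(x)

    model1_vars = tf1.get_collection(
        tf1.GraphKeys.GLOBAL_VARIABLES, scope="model1/")
    # Saver restricted to Model 1's variables
    saver = tf1.train.Saver(var_list=model1_vars)

    with tf1.variable_scope("model2"):
        v = tf1.get_variable("v", shape=[3, 1])
        prediction = tf.matmul(features, v)
    loss = tf.reduce_mean(tf.square(prediction - y))

    model2_vars = tf1.get_collection(
        tf1.GraphKeys.GLOBAL_VARIABLES, scope="model2/")
    # Plain SGD creates no extra variables; var_list keeps Model 1 frozen
    train_op = tf1.train.GradientDescentOptimizer(0.01).minimize(
        loss, var_list=model2_vars)
    init2 = tf1.variables_initializer(model2_vars)

    with tf1.Session() as sess:
        saver.restore(sess, tf1.train.latest_checkpoint(ckpt_dir))
        sess.run(init2)  # initializes Model 2 only
        rng = np.random.default_rng(0)
        for _ in range(5):
            batch_x = rng.normal(size=(8, 4)).astype(np.float32)
            batch_y = rng.normal(size=(8, 1)).astype(np.float32)
            sess.run(train_op, feed_dict={x: batch_x, y: batch_y})
        restored_w = sess.run(model1_vars[0])
        trained_v = sess.run(v)

print(np.allclose(saved_w, restored_w))  # Model 1 weights untouched by training
```

Passing var_list to minimize() is a choice for this sketch; drop it if Model 1 should be fine-tuned too. The optimizer is plain gradient descent on purpose: optimizers with internal state, such as Adam, create extra variables (for example beta1_power in TF 1.x) that are not necessarily inside the "model2" scope, and their slot variables only exist after minimize() is called. In that case, build the initializer list after minimize() and make sure it also covers those optimizer variables.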