
Unable to restore variables of Adam Optimizer while using tf.train.Saver

I get the following errors when I try to restore a saved model in TensorFlow:

 W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key out_w/Adam_5 not found in checkpoint
 W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key b1/Adam not found in checkpoint
 W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key b1/Adam_4 not found in checkpoint

I guess I am unable to save the variables of the Adam optimizer. Any fix?

Consider this small experiment:

import tensorflow as tf

def simple_model(X):
    with tf.variable_scope('Layer1'):
        w1 = tf.get_variable('w1', initializer=tf.truncated_normal((5, 2)))
        b1 = tf.get_variable('b1', initializer=tf.ones((2)))
        layer1 = tf.matmul(X, w1) + b1
    return layer1

def simple_model2(X):
    with tf.variable_scope('Layer1'):
        w1 = tf.get_variable('w1_x', initializer=tf.truncated_normal((5, 2)))
        b1 = tf.get_variable('b1_x', initializer=tf.ones((2)))
        layer1 = tf.matmul(X, w1) + b1
    return layer1

tf.reset_default_graph()
X = tf.placeholder(tf.float32, shape = (None, 5))
model = simple_model(X)
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, './Checkpoint', global_step = 0)

tf.reset_default_graph()
X = tf.placeholder(tf.float32, shape = (None, 5))
model = simple_model(X)      # Case 1
#model = simple_model2(X)    # Case 2
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().restore(sess, tf.train.latest_checkpoint('.'))

In Case 1, everything works fine. But in Case 2 you will get errors like Key Layer1/b1_x not found in checkpoint, because the variable names in the graph differ from the names stored in the checkpoint (even though the shapes and data types of the variables are the same). Make sure the variables in the model you are restoring into have exactly the same names as the ones that were saved.
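
If you really do need different names in the new graph, tf.train.Saver also accepts a dictionary mapping checkpoint names to variables, so Case 2 can be restored by spelling out that mapping explicitly. A minimal sketch, assuming the checkpoint written by the saving code above (which contains Layer1/w1 and Layer1/b1):

tf.reset_default_graph()
X = tf.placeholder(tf.float32, shape=(None, 5))
model = simple_model2(X)    # new names: Layer1/w1_x, Layer1/b1_x

# Map the names stored in the checkpoint to the variables in the current graph
graph_vars = {v.op.name: v for v in tf.global_variables()}
saver = tf.train.Saver({'Layer1/w1': graph_vars['Layer1/w1_x'],
                        'Layer1/b1': graph_vars['Layer1/b1_x']})

with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('.'))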

To check the names of the variables present in the checkpoint, check this answer.
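
For example, in recent TensorFlow 1.x releases you should be able to list the (name, shape) pairs stored in a checkpoint with tf.train.list_variables (a quick sketch; older versions exposed the same helper as tf.contrib.framework.list_variables):

import tensorflow as tf

# Print every variable name and shape stored in the checkpoint
for name, shape in tf.train.list_variables(tf.train.latest_checkpoint('.')):
    print(name, shape)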

This can also happen when you are not training every variable simultaneously, because then the checkpoint contains Adam parameters for only some of the variables.
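
You can see how that happens with a small hypothetical sketch reusing simple_model from above: if the optimizer is given only w1 to update, Adam slot variables are created only for w1, so a checkpoint from that graph has no b1/Adam entries for a later graph to restore.

tf.reset_default_graph()
X = tf.placeholder(tf.float32, shape=(None, 5))
loss = tf.reduce_sum(simple_model(X))    # hypothetical loss, just to have gradients

# Train only w1, so Adam slots (.../Adam, .../Adam_1) are created only for w1
w1_only = [v for v in tf.trainable_variables() if 'w1' in v.name]
train_op = tf.train.AdamOptimizer().minimize(loss, var_list=w1_only)

print([v.name for v in tf.global_variables() if 'Adam' in v.name])
# e.g. ['Layer1/w1/Adam:0', 'Layer1/w1/Adam_1:0'] -- nothing for b1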

One possible fix would be to "reset" Adam after loading the checkpoint. To do this, filter out the Adam-related variables when creating the saver:

vl = [v for v in tf.global_variables() if "Adam" not in v.name]
saver = tf.train.Saver(var_list=vl)

Make sure the Adam variables still get initialized: either run tf.global_variables_initializer() before restoring, or initialize just the excluded variables afterwards.
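
Putting it together, a minimal sketch (assuming a checkpoint that contains the model weights but not the Adam state; note that AdamOptimizer also creates beta1_power/beta2_power accumulators, which are filtered out here for the same reason):

tf.reset_default_graph()
X = tf.placeholder(tf.float32, shape=(None, 5))
loss = tf.reduce_sum(simple_model(X))                 # hypothetical loss
train_op = tf.train.AdamOptimizer().minimize(loss)    # creates the .../Adam slot variables

# Restore only the variables that are actually in the checkpoint
vl = [v for v in tf.global_variables()
      if 'Adam' not in v.name
      and 'beta1_power' not in v.name
      and 'beta2_power' not in v.name]
saver = tf.train.Saver(var_list=vl)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())           # gives Adam a fresh state
    saver.restore(sess, tf.train.latest_checkpoint('.'))  # overwrites only the model weights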
