繁体   English   中英

如何恢复预训练模型以初始化参数

[英]How to restore pretrained model to initialize parameters

我已经下载了一个带有预训练模型的网络。 我在网络中添加了几个图层和参数,我想使用这个预训练模型初始化原始参数,并随机初始化新添加的参数。我使用此代码:

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "output/saver-test")
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

但是我遇到了错误:“检查点中找不到关键的global_step”,这个错误因为我有一些在预训练模型中不存在的新参数。但是我该如何解决这个问题呢? 更重要的是,我想使用这个代码“sess.run(tf.global_variables_initializer())”来初始化新添加的参数,但是从预训练模型中提取的参数将被它覆盖?

这是因为您的网络与加载的网络不完全匹配。 您可以使用类似的选择性检查点加载器:

  reader = tf.train.NewCheckpointReader(os.path.join(checkpoint_dir, ckpt_name))
    restore_dict = dict()
    for v in tf.trainable_variables():
        tensor_name = v.name.split(':')[0]
        if reader.has_tensor(tensor_name):
            print('has tensor ', tensor_name)
            restore_dict[tensor_name] = v

    restore_dict['my_new_var_scope/my_new_var'] = self.get_my_new_var_variable()

get_my_new_var_variable()就是这样的:

    def get_my_new_var_variable(self):
    with tf.variable_scope("my_new_var_scope",reuse=tf.AUTO_REUSE):
        my_new_var = tf.get_variable("my_new_var", dtype=tf.int32,initializer=tf.constant([23, 42]))
    return my_new_var

装载重量:

self.saver = tf.train.Saver(restore_dict)
    self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))

编辑:

请注意,为了避免覆盖已加载的变量,您可以使用此方法:

def initialize_uninitialized(sess):
  global_vars = tf.global_variables()
  is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
  not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f]
  if len(not_initialized_vars):
    sess.run(tf.variables_initializer(not_initialized_vars))

或者只是在加载变量之前调用tf.global_variables_initializer()应该在这里工作。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM