繁体   English   中英

如何在Tensorflow中为当前模型恢复预训练的检查点?

[英]How to restore pretrained checkpoint for current model in Tensorflow?

我有一个训练有素的检查站。 现在,我正在尝试将此预训练模型恢复到当前网络。 但是,变量名不同。 Tensorflow文档说使用字典像:

v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)
saver = tf.train.Saver({"v2": v2})

但是,当前网络中的变量定义如下:

with tf.variable_scope('a'):
    b=tf.get_variable(......)

因此,变量名称似乎是a/b 如何使字典像"v2": a/b

您可以使用tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)获取当前图中所有变量名称的列表。 您还可以指定范围。

tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='a')

您可以使用tf.train.list_variables(ckpt_file)获取检查点中所有变量的列表。

假设您的检查点中有变量b,并且您想在名称为a/b tf.variable_scope('a')内部加载。 为此,您只需定义它

with tf.variable_scope('a'):
    b=tf.get_variable(......)

和负载

saver = tf.train.Saver({'v2': b})

with tf.Session() as sess:
    saver.restore(sess, ckpt_file))
    print(b)

这将输出

<tf.Variable 'a/b:0' shape dtype>

编辑:如前所述,您可以使用

vars_dict = {}
for var_current in tf.global_variables():
    print(var_current)
    print(var_current.op.name) # this gets only name

for var_ckpt in tf.train.list_variables(ckpt):
    print(var_ckpt[0]) this gets only name

当您知道所有变量的确切名称时,您可以分配所需的任何值,只要变量具有相同的形状和dtype即可

vars_dict[var_ckpt[0]) = tf.get_variable(var_current.op.name, shape) # remember to specify shape, you can always get it from var_current 

您可以显式地或以您认为合适的任何循环构造此字典。 然后将其传递给保护程序

saver = tf.train.Saver(vars_dict)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM