[英]Tensorflow: how to save variables and load them to different variables?
Let's say I have two identical networks, A
and B
.假设我有两个相同的网络A
和B
。 I saved (using Saver
) a previous state of network A
, and now I would like to load it into network B
(all happens during the same run).我保存(使用Saver
)网络A
的先前状态,现在我想将它加载到网络B
(所有这些都发生在同一次运行中)。 How can I do this?我怎样才能做到这一点?
Let me provide an example.让我举一个例子。 First, let's define and save some variables:首先,让我们定义并保存一些变量:
import tensorflow as tf
v1 = tf.Variable(tf.ones(1), name='v1')
v2 = tf.Variable(2 * tf.ones(1), name='v2')
saver = tf.train.Saver(tf.trainable_variables())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.save(sess, './tmp.ckpt')
Now, let's define some variables with the same names in a new graph, and load their values from the checkpoint:现在,让我们在新图中定义一些具有相同名称的变量,并从检查点加载它们的值:
with tf.Graph().as_default():
assert len(tf.trainable_variables()) == 0
v1 = tf.Variable(tf.zeros(1), name='v1')
v2 = tf.Variable(tf.zeros(1), name='v2')
saver = tf.train.Saver(tf.trainable_variables())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, './tmp.ckpt')
print(sess.run([v1, v2]))
The last line prints:最后一行打印:
[array([1.], dtype=float32), array([2.], dtype=float32)]
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.