简体   繁体   English

Tensorflow:如何保存变量并将它们加载到不同的变量?

[英]Tensorflow: how to save variables and load them to different variables?

Let's say I have two identical networks, A and B .假设我有两个相同的网络AB I saved (using Saver ) a previous state of network A , and now I would like to load it into network B (all happens during the same run).我保存(使用Saver )网络A的先前状态,现在我想将它加载到网络B (所有这些都发生在同一次运行中)。 How can I do this?我怎样才能做到这一点?

Let me provide an example.让我举一个例子。 First, let's define and save some variables:首先,让我们定义并保存一些变量:

import tensorflow as tf

v1 = tf.Variable(tf.ones(1), name='v1')
v2 = tf.Variable(2 * tf.ones(1), name='v2')

saver = tf.train.Saver(tf.trainable_variables())
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, './tmp.ckpt')

Now, let's define some variables with the same names in a new graph, and load their values from the checkpoint:现在,让我们在新图中定义一些具有相同名称的变量,并从检查点加载它们的值:

with tf.Graph().as_default():
    assert len(tf.trainable_variables()) == 0
    v1 = tf.Variable(tf.zeros(1), name='v1')
    v2 = tf.Variable(tf.zeros(1), name='v2')

    saver = tf.train.Saver(tf.trainable_variables())
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, './tmp.ckpt')
        print(sess.run([v1, v2]))

The last line prints:最后一行打印:

[array([1.], dtype=float32), array([2.], dtype=float32)]

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM