
TensorFlow: how to save variables and load them into different variables?

Let's say I have two identical networks, A and B. I saved (using Saver) a previous state of network A, and now I would like to load it into network B (all within the same run). How can I do this?

Let me provide an example. First, let's define and save some variables:

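import tensorflow as tf

# Two variables with distinct initial values, named 'v1' and 'v2'
v1 = tf.Variable(tf.ones(1), name='v1')
v2 = tf.Variable(2 * tf.ones(1), name='v2')

# The Saver writes each variable to the checkpoint under its name
saver = tf.train.Saver(tf.trainable_variables())
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, './tmp.ckpt')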
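Restoring by name works because the Saver stores each variable under its name. A minimal sketch for checking which names a checkpoint contains, assuming TensorFlow with the `tf.compat.v1` API (1.15 or 2.x) is installed; the helper `save_and_list` is hypothetical and writes to a temporary directory instead of `./tmp.ckpt`. It uses `tf.train.list_variables`, which returns `(name, shape)` pairs:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

# Saver/Session are graph-mode APIs; this keeps them usable under TF 2.x too.
tf.disable_eager_execution()


def save_and_list(ckpt_dir):
    """Hypothetical helper: save 'v1' and 'v2', then list the checkpoint keys."""
    with tf.Graph().as_default():
        v1 = tf.Variable(tf.ones(1), name='v1')
        v2 = tf.Variable(2 * tf.ones(1), name='v2')
        saver = tf.train.Saver([v1, v2])
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            prefix = saver.save(sess, os.path.join(ckpt_dir, 'tmp.ckpt'))
    # Each entry is (name, shape); these names are what restore() looks up.
    return tf.train.list_variables(prefix)


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as d:
        print(save_and_list(d))
```

Because the keys are plain variable names, any graph that contains variables named `v1` and `v2` can restore them, which is what the next snippet relies on.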

Now, let's define some variables with the same names in a new graph, and load their values from the checkpoint:

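with tf.Graph().as_default():
    # A fresh graph starts with no variables
    assert len(tf.trainable_variables()) == 0
    # Same names as before, but different initial values (zeros)
    v1 = tf.Variable(tf.zeros(1), name='v1')
    v2 = tf.Variable(tf.zeros(1), name='v2')

    saver = tf.train.Saver(tf.trainable_variables())
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # restore() matches each variable by name and overwrites the zeros
        saver.restore(sess, './tmp.ckpt')
        print(sess.run([v1, v2]))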

The last line prints:

[array([1.], dtype=float32), array([2.], dtype=float32)]
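The example above only works because the names match. To load A's saved state into B's differently named variables, one option is to pass `tf.train.Saver` a dict as `var_list`, mapping names stored in the checkpoint to the variables that should receive the values. The sketch below assumes both networks live in one graph under hypothetical name scopes `A` and `B`, uses the `tf.compat.v1` API (TensorFlow 1.15 or 2.x), and the helper `load_a_into_b` is illustrative only:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

# Saver/Session are graph-mode APIs; this keeps them usable under TF 2.x too.
tf.disable_eager_execution()


def load_a_into_b(ckpt_dir):
    """Hypothetical helper: save network A, then restore it into network B."""
    with tf.Graph().as_default():
        # Hypothetical scopes 'A' and 'B' stand in for the two identical networks.
        with tf.name_scope('A'):
            tf.Variable(tf.ones(1), name='v1')
            tf.Variable(2 * tf.ones(1), name='v2')
        with tf.name_scope('B'):
            b1 = tf.Variable(tf.zeros(1), name='v1')
            b2 = tf.Variable(tf.zeros(1), name='v2')

        # A list var_list stores A's variables under their names: 'A/v1', 'A/v2'.
        saver_a = tf.train.Saver(tf.global_variables(scope='A/'))

        # A dict var_list maps checkpoint names to target variables, so each
        # B variable ('B/v1') is filled from the matching A entry ('A/v1').
        b_vars = tf.global_variables(scope='B/')
        saver_b = tf.train.Saver(
            {v.op.name.replace('B/', 'A/', 1): v for v in b_vars})

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            prefix = saver_a.save(sess, os.path.join(ckpt_dir, 'a.ckpt'))
            saver_b.restore(sess, prefix)
            return sess.run([b1, b2])


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as d:
        print(load_a_into_b(d))
```

The dict comprehension assumes B's variable names differ from A's only in the leading scope; with another naming scheme the mapping would have to be built differently, but the dict form of `var_list` stays the same.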

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, contact yoyou2525@163.com.
