
How to reset a variable in tensorflow?

I want to execute a graph, change the value of a single variable, and then execute the graph again with the change absorbed by any downstream variables. As an example,

import tensorflow as tf  # TensorFlow 1.x API

A = tf.distributions.Normal(0.0, 1.0)
B = tf.distributions.Normal(0.0, 1.0)

# tf.get_variable accepts a Tensor as initializer; A.sample() draws one value
a = tf.get_variable(name="a", initializer=A.sample())
b = tf.get_variable(name="b", initializer=B.sample())

C = tf.distributions.Normal(a + b, 1.0)

c = tf.get_variable(name="c", initializer=C.sample())

So, if I run this graph,

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    session.run([a, b, c])

I get a set of values for a, b, and c. Then, say, I want to re-initialize b,

    session.run(b.initializer)  # still inside the `with tf.Session()` block
    session.run([a, b, c])

This re-initializes b, but the change is not propagated to c. Since c depends on b (through C), I want c to be re-initialized whenever b changes.
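To make the observed behaviour concrete without TensorFlow, here is a toy pure-Python model. It is an analogy under stated assumptions, not TensorFlow's implementation: the `ToyVariable` class is hypothetical, and the noise of `Normal(a + b, 1.0)` is dropped so that c's initializer is simply `a + b`. The point it shows is that a variable keeps whatever its initializer produced the last time it ran.

```python
import random


class ToyVariable:
    """Hypothetical stand-in for a TF1 variable: it holds a stored value and
    reads other variables only while its own initializer is running."""

    def __init__(self, initializer):
        self._initializer = initializer  # zero-argument callable
        self.value = None

    def initialize(self):
        # Evaluate the initializer once and keep the result (a snapshot).
        self.value = self._initializer()


rng = random.Random(0)
a = ToyVariable(lambda: rng.gauss(0.0, 1.0))
b = ToyVariable(lambda: rng.gauss(0.0, 1.0))
# c's initializer reads a and b; the Normal(a + b, 1.0) noise is dropped for clarity.
c = ToyVariable(lambda: a.value + b.value)

# Rough analogue of global_variables_initializer(), run in dependency order.
for var in (a, b, c):
    var.initialize()

b_before, c_before = b.value, c.value
b.initialize()  # rough analogue of session.run(b.initializer)

print(b.value != b_before)  # b now holds a new sample
print(c.value == c_before)  # True: c still holds its old snapshot
```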

Is this possible in tensorflow?

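c does not depend on b; only its initializer does. In TensorFlow, a variable just stores a value: c's initializer reads b once, when it runs, and the result is kept in c. A variable is never recomputed automatically when another variable changes. So after you re-initialize b, the value of c doesn't change. But if you run c's initializer again, it will pick up the new value of b. So simply do: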

session.run(b.initializer)  # draw a new value for b
session.run(c.initializer)  # re-run c's initializer, which reads the new b
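The two lines above handle one level of dependency. If the chain is longer (say a hypothetical d whose initializer reads c), the same idea extends: re-run the initializer of every variable downstream of b, each one after the variables it reads. Below is a toy pure-Python sketch of that ordering; `ToyVariable`, `reinitialize_downstream` and `d` are hypothetical names for illustration, not TensorFlow API. In TF 1.x code the analogue would be running `b.initializer`, then `c.initializer`, then `d.initializer`, one `session.run` call each, since ops fetched together in a single call have no guaranteed execution order.

```python
import random
from graphlib import TopologicalSorter


class ToyVariable:
    """Hypothetical stand-in for a TF1 variable that records which
    variables its initializer reads (`deps`)."""

    def __init__(self, name, initializer, deps=()):
        self.name = name
        self._initializer = initializer  # zero-argument callable
        self.deps = tuple(deps)
        self.value = None

    def initialize(self):
        self.value = self._initializer()


def reinitialize_downstream(changed, variables):
    """Hypothetical helper: re-run the initializer of `changed` and of every
    variable whose initializer reads it, directly or transitively, with each
    variable re-initialized after the variables it reads."""
    by_name = {v.name: v for v in variables}
    # Map each variable name to the names its initializer reads.
    graph = {v.name: {d.name for d in v.deps} for v in variables}

    # Collect `changed` plus everything downstream of it.
    affected = {changed.name}
    grew = True
    while grew:
        grew = False
        for name, deps in graph.items():
            if name not in affected and deps & affected:
                affected.add(name)
                grew = True

    # static_order() yields every variable after the ones it reads.
    for name in TopologicalSorter(graph).static_order():
        if name in affected:
            by_name[name].initialize()
    return affected


rng = random.Random(0)
a = ToyVariable("a", lambda: rng.gauss(0.0, 1.0))
b = ToyVariable("b", lambda: rng.gauss(0.0, 1.0))
c = ToyVariable("c", lambda: a.value + b.value, deps=(a, b))
d = ToyVariable("d", lambda: 2.0 * c.value, deps=(c,))
variables = [a, b, c, d]

for var in variables:  # this list is already in dependency order
    var.initialize()

a_before = a.value
affected = reinitialize_downstream(b, variables)
print(sorted(affected))  # ['b', 'c', 'd']
print(c.value == a.value + b.value, d.value == 2.0 * c.value)  # True True
```

Tracking dependencies explicitly keeps the helper independent of how variables are created; a is left untouched because nothing upstream of it changed.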
