
Reset tensorflow Optimizer

I am loading from a saved model and I would like to be able to reset a tensorflow optimizer such as an Adam Optimizer. Ideally something like:

sess.run([tf.initialize_variables(Adamopt)])

or

sess.run([Adamopt.reset])

I have tried looking for an answer but have yet to find a way to do it. Here's what I've found so far, none of which addresses the issue: https://github.com/tensorflow/tensorflow/issues/634

In TensorFlow is there any way to just initialize uninitialised variables?

Tensorflow: Using Adam optimizer

I basically just want a way to reset the "slot" variables in the Adam Optimizer.

Thanks

This question bothered me for quite a while too. It's actually quite easy: you define an operation that re-initializes the optimizer's current state, which you can obtain via its variables() method, something like this:

optimizer = tf.train.AdamOptimizer(0.1, name='Optimizer')
reset_optimizer_op = tf.variables_initializer(optimizer.variables())

Whenever you need to reset the optimizer, run:

sess.run(reset_optimizer_op)
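For completeness, here is a self-contained sketch of that pattern. The variable, loss, and learning rate are illustrative, and `tf.compat.v1` is used so the graph-mode code also runs under TensorFlow 2.x:

```python
import tensorflow as tf

# Graph-mode (TF 1.x-style) execution via the compat layer.
tf.compat.v1.disable_eager_execution()
tf.compat.v1.reset_default_graph()

# Illustrative variable and loss.
x = tf.compat.v1.get_variable("x", initializer=5.0)
loss = x * x

optimizer = tf.compat.v1.train.AdamOptimizer(0.1, name="Optimizer")
train_op = optimizer.minimize(loss)

# One op that re-initializes every piece of optimizer state
# (the 'm' and 'v' slots plus the beta power accumulators).
reset_optimizer_op = tf.compat.v1.variables_initializer(optimizer.variables())

sess = tf.compat.v1.Session()
sess.run(tf.compat.v1.global_variables_initializer())

sess.run(train_op)  # after one step the 'm' slot is non-zero
m_after_step = float(sess.run(optimizer.get_slot(x, "m")))

sess.run(reset_optimizer_op)  # back to the freshly initialized state
m_after_reset = float(sess.run(optimizer.get_slot(x, "m")))
sess.close()
```

Note that `reset_optimizer_op` touches only the optimizer's variables; the model weights (`x` here) keep their trained values.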

Official explanation of variables():

A list of variables which encode the current state of Optimizer. Includes slot variables and additional global variables created by the optimizer in the current default graph.

E.g., for AdamOptimizer you basically get the first and second moments (slot names 'm' and 'v') of all trainable variables, as well as the beta1_power and beta2_power accumulators.

In TensorFlow 2.x, e.g. with the Adam optimizer, you can reset it like this:

for var in optimizer.variables():
    var.assign(tf.zeros_like(var))
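As a self-contained sketch of that reset (the variable and loss here are illustrative, and a small shim covers the fact that `variables` is a method on older TF 2.x optimizers but a property on newer Keras releases):

```python
import tensorflow as tf

def opt_state(opt):
    # `variables` is a method in older TF 2.x optimizers and a
    # property in newer Keras releases; handle both.
    v = opt.variables
    return list(v()) if callable(v) else list(v)

# Illustrative variable and loss.
var = tf.Variable([1.0, -2.0])
optimizer = tf.keras.optimizers.Adam(0.1)

# One training step so the optimizer lazily creates its state
# (the 'm'/'v' slots and the step counter).
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(var ** 2)
optimizer.apply_gradients(zip(tape.gradient(loss, [var]), [var]))

# Reset: zero every optimizer state variable.
for v in opt_state(optimizer):
    v.assign(tf.zeros_like(v))
```

Zeroing is appropriate here because Adam's state is zero-initialized; if the state was restored from a checkpoint, see the snapshot-and-restore variant below in this thread.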

The simplest way I found was to give the optimizer its own variable scope and then run:

optimizer_scope = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                    "scope/prefix/for/optimizer")
sess.run(tf.variables_initializer(optimizer_scope))

The idea comes from freeze weights.

Building on @EdisonLeejt's answer for TensorFlow 2.x: more generally, you can first capture the initial state (which may not be zero, e.g. if it was loaded from a checkpoint file) and later assign it back, i.e.:

# Get the initial optimizer state
init_states = [var.value() for var in optimizer.variables()]

# Do the optimization
...

# Reset the optimizer state to init_states
for val, var in zip(init_states, optimizer.variables()):
    var.assign(val)
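Put together as a runnable sketch: the variable, loss, and the single "training" step are illustrative; `tf.identity` makes an explicit snapshot of each state variable, and a small shim covers `variables` being a method or a property depending on the TF release:

```python
import tensorflow as tf

def opt_state(opt):
    # `variables` is a method in older TF 2.x optimizers and a
    # property in newer Keras releases; handle both.
    v = opt.variables
    return list(v()) if callable(v) else list(v)

# Illustrative variable and loss.
var = tf.Variable([1.0, -2.0])
optimizer = tf.keras.optimizers.Adam(0.1)

# One step so the optimizer builds its state variables (they are
# created lazily, on the first apply_gradients call).
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(var ** 2)
optimizer.apply_gradients(zip(tape.gradient(loss, [var]), [var]))

# Snapshot the current optimizer state.
init_states = [tf.identity(v) for v in opt_state(optimizer)]

# ... more optimization steps ...
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(var ** 2)
optimizer.apply_gradients(zip(tape.gradient(loss, [var]), [var]))

# Reset the optimizer state back to the snapshot.
for val, v in zip(init_states, opt_state(optimizer)):
    v.assign(val)
```

This variant works regardless of how the initial state was produced, e.g. when the optimizer state was restored from a checkpoint rather than zero-initialized.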
