
Tensorflow `tf.layers.batch_normalization` doesn't add update ops to `tf.GraphKeys.UPDATE_OPS`

The following code (copy/paste runnable) illustrates using tf.layers.batch_normalization.

import tensorflow as tf
bn = tf.layers.batch_normalization(tf.constant([0.0]))
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))

> []     # UPDATE_OPS collection is empty

I am using TF 1.5, and the documentation (quoted below) clearly states that UPDATE_OPS should not be empty in this case (https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization):

Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. For example:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)

Just switch your code to training mode (by setting the training flag to True), as mentioned in the quote:

Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op.

import tensorflow as tf
bn = tf.layers.batch_normalization(tf.constant([0.0]), training=True)
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))

will output:

[<tf.Tensor 'batch_normalization/AssignMovingAvg:0' shape=(1,) dtype=float32_ref>,
 <tf.Tensor 'batch_normalization/AssignMovingAvg_1:0' shape=(1,) dtype=float32_ref>]

and gamma and beta end up in the TRAINABLE_VARIABLES collection:

print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))

[<tf.Variable 'batch_normalization/gamma:0' shape=(1,) dtype=float32_ref>, 
 <tf.Variable 'batch_normalization/beta:0' shape=(1,) dtype=float32_ref>]
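
To see why the collection is only populated in training mode, it may help to look at what the two AssignMovingAvg ops listed above compute. The following is a plain-Python sketch, not TensorFlow's actual implementation: the helper names batch_stats and batch_norm_step are hypothetical, and the use of the population (biased) variance is an assumption. The momentum of 0.99 and the initial values of 0 for moving_mean and 1 for moving_variance are the defaults of tf.layers.batch_normalization.

```python
# Plain-Python sketch of the moving-average update that the two
# AssignMovingAvg ops perform; not TensorFlow's actual implementation.

MOMENTUM = 0.99  # default `momentum` of tf.layers.batch_normalization


def batch_stats(batch):
    """Return the mean and population variance of a list of floats."""
    mean = sum(batch) / len(batch)
    # Assumption: biased (divide by N) variance is used for the batch.
    variance = sum((x - mean) ** 2 for x in batch) / len(batch)
    return mean, variance


def batch_norm_step(batch, moving_mean, moving_variance, training):
    """Return the (moving_mean, moving_variance) pair after one step.

    In inference mode the moving statistics are only read, so there is
    nothing to update; this mirrors the empty UPDATE_OPS collection.
    """
    if not training:
        return moving_mean, moving_variance
    mean, variance = batch_stats(batch)
    new_mean = MOMENTUM * moving_mean + (1.0 - MOMENTUM) * mean
    new_variance = MOMENTUM * moving_variance + (1.0 - MOMENTUM) * variance
    return new_mean, new_variance


# Start from the layer's default initial statistics.
state = (0.0, 1.0)
batch = [1.0, 2.0, 3.0]
print(batch_norm_step(batch, *state, training=False))  # statistics unchanged
print(batch_norm_step(batch, *state, training=True))   # statistics nudged toward the batch
```

With training=False the function returns the statistics untouched, so there is no update to register. With training=True each step moves the stored statistics slightly toward the current batch, which is the work that has to be attached to the train_op through tf.control_dependencies.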
