I'm migrating TensorFlow code to TensorFlow 2.1.0.
Here is the original code:
conv = tf.layers.conv2d(inputs, out_channels, kernel_size=3, padding='SAME')
conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
conv = tf.nn.relu(conv)
conv = tf.contrib.layers.max_pool2d(conv, 2)
And this is what I've done:
conv1 = Conv2D(out_channels, (3, 3), activation='relu', padding='same', data_format='channels_last', name=name)(inputs)
conv1 = Conv2D(64, (5, 5), activation='relu', padding='same', data_format="channels_last")(conv1)
#conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
pool1 = MaxPooling2D(pool_size=(2, 2), data_format="channels_last")(conv1)
My problem is that I don't know what to do with tf.contrib.layers.batch_norm. How can I migrate tf.contrib.layers.batch_norm to TensorFlow 2.x?
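For reference, here is a minimal sketch of a layer-for-layer Keras equivalent of the four original lines, assuming channels_last (NHWC) input. The function name conv_block and the 28x28x1 input shape are illustrative only. Unlike the attempt above, the Conv2D has no activation, because the original applies ReLU after batch normalization. tf.contrib.layers.max_pool2d(conv, 2) defaults to stride 2 and 'VALID' padding, which is what MaxPooling2D(pool_size=2) does by default.

```python
import tensorflow as tf
from tensorflow.keras import layers


def conv_block(inputs, out_channels):
    """Keras version of: conv2d -> batch_norm -> relu -> max_pool2d."""
    # No activation here: the original applies ReLU *after* batch norm.
    x = layers.Conv2D(out_channels, kernel_size=3, padding="same")(inputs)
    # decay=0.99 in tf.contrib corresponds to momentum=0.99 in Keras.
    x = layers.BatchNormalization(momentum=0.99, scale=True, center=True)(x)
    x = layers.ReLU()(x)
    # 2x2 window, stride 2, 'valid' padding (the max_pool2d defaults).
    x = layers.MaxPooling2D(pool_size=2)(x)
    return x


inputs = tf.keras.Input(shape=(28, 28, 1))
model = tf.keras.Model(inputs, conv_block(inputs, 64))
print(model(tf.zeros((1, 28, 28, 1))).shape)  # (1, 14, 14, 64)
```

Keeping the bias in Conv2D mirrors tf.layers.conv2d; it is redundant in front of batch norm (the BN beta absorbs it), but leaving it in keeps the parameter layout closer to the original.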
UPDATE:
Following the suggestion in the comments, I think I have migrated it correctly:
conv1 = BatchNormalization(momentum=0.99, scale=True, center=True)(conv1)
But I'm not sure whether decay is the same as momentum, and I don't know how to set updates_collections in the BatchNormalization layer.
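A quick check, assuming TF 2.x eager execution: Keras' momentum follows the same update rule as tf.contrib's decay, moving_mean = momentum * moving_mean + (1 - momentum) * batch_mean, so decay=0.99 maps to momentum=0.99 (which happens to be the Keras default; tf.contrib defaulted to decay=0.999). There is also nothing to set for updates_collections: updates_collections=None asked tf.contrib to run the moving-average updates in place, and in TF 2.x the Keras layer applies those updates itself whenever it is called with training=True. The input batch below is arbitrary.

```python
import numpy as np
import tensorflow as tf

# Keras `momentum` follows the same rule as tf.contrib's `decay`:
#   moving_mean = momentum * moving_mean + (1 - momentum) * batch_mean
bn = tf.keras.layers.BatchNormalization(momentum=0.99, scale=True, center=True)

# Arbitrary example batch: 32 samples, 4 features.
x = np.random.RandomState(0).normal(loc=3.0, size=(32, 4)).astype("float32")

# With training=True the layer normalizes with the batch statistics and, in
# eager mode, updates its moving averages right away: no update ops to collect.
bn(tf.constant(x), training=True)

# moving_mean starts at zeros, so after one step it is (1 - momentum) * batch_mean.
expected = 0.99 * 0.0 + (1 - 0.99) * x.mean(axis=0)
print(np.allclose(bn.moving_mean.numpy(), expected, atol=1e-5))  # True
```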
I ran into this problem while working with a trained model that I was going to fine-tune. Just replacing tf.contrib.layers.batch_norm with tf.keras.layers.BatchNormalization, as the OP did, gave me an error; the fix is described below.
The old code looked like this:
tf.contrib.layers.batch_norm(
    tensor,
    scale=True,
    center=True,
    is_training=self.use_batch_statistics,
    trainable=True,
    data_format=self._data_format,
    updates_collections=None,
)
and the updated working code looks like this:
tf.keras.layers.BatchNormalization(
    name="BatchNorm",
    scale=True,
    center=True,
    trainable=True,
)(tensor)
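The removed arguments do have Keras counterparts: data_format is expressed through axis (the channel axis, 1 for 'NCHW' and -1 for 'NHWC', the two tf.contrib values), is_training becomes the training argument passed when the layer is called, and updates_collections=None needs no replacement. A minimal sketch under those assumptions; the helper name batch_norm and its parameters are hypothetical, and it returns the layer only so the moving statistics can be inspected. In a real model you would create the layer once rather than on every call.

```python
import numpy as np
import tensorflow as tf


def batch_norm(tensor, is_training, data_format="NHWC", name="BatchNorm"):
    """Hypothetical helper: tf.contrib-style arguments mapped onto Keras.

    data_format -> axis: the channel axis is 1 for NCHW, -1 for NHWC.
    is_training -> the `training` argument given when the layer is called.
    updates_collections=None -> nothing to set; calling with training=True
    updates the moving mean/variance as part of the call.
    """
    axis = 1 if data_format == "NCHW" else -1
    layer = tf.keras.layers.BatchNormalization(
        axis=axis, name=name, scale=True, center=True, trainable=True)
    return layer(tensor, training=is_training), layer


data = np.random.RandomState(0).normal(loc=1.0, size=(2, 5, 5, 3))
x = tf.constant(data.astype("float32"))

# training=True: normalize with batch statistics and update moving averages.
y, layer = batch_norm(x, is_training=True)
# training=False (e.g. fine-tuning with frozen statistics) would instead use
# the stored moving_mean/moving_variance restored from the checkpoint.
print(y.shape, layer.moving_mean.numpy())
```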
I'm not sure whether removing the other keyword arguments will cause problems, but everything seems to work. Note the name="BatchNorm" argument: the layers use a different naming scheme, so I had to use the inspect_checkpoint.py tool to look at the model and find the layer names, which happened to be BatchNorm.
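If you would rather do that name lookup from Python than with inspect_checkpoint.py, tf.train.list_variables returns the (name, shape) pairs stored in a checkpoint. A minimal sketch; the throwaway checkpoint written here is only a stand-in for the real pretrained one. Because it is saved with the object-based tf.train.Checkpoint API, its names look roughly like bn/moving_mean/.ATTRIBUTES/VARIABLE_VALUE, whereas a TF1 checkpoint produced by tf.contrib.layers.batch_norm uses scope-based names containing BatchNorm/moving_mean.

```python
import os
import tempfile

import tensorflow as tf


def print_checkpoint_variables(ckpt_path):
    """Print the (name, shape) pairs stored in a checkpoint, much like
    inspect_checkpoint.py does; useful for spotting names such as BatchNorm."""
    for name, shape in tf.train.list_variables(ckpt_path):
        print(name, shape)


# Stand-in for the real pretrained checkpoint: save one BatchNormalization layer.
bn = tf.keras.layers.BatchNormalization(name="BatchNorm")
bn.build((None, 4))
ckpt_path = tf.train.Checkpoint(bn=bn).save(os.path.join(tempfile.mkdtemp(), "ckpt"))

print_checkpoint_variables(ckpt_path)
```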