
Migrate tf.contrib.layers.batch_norm to TensorFlow 2.0

I'm migrating TensorFlow code to TensorFlow 2.1.0.

Here is the original code:

conv = tf.layers.conv2d(inputs, out_channels, kernel_size=3, padding='SAME')
conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
conv = tf.nn.relu(conv)
conv = tf.contrib.layers.max_pool2d(conv, 2)

And this is what I've done:

conv1 = Conv2D(out_channels, (3, 3), activation='relu', padding='same', data_format='channels_last', name=name)(inputs)
conv1 = Conv2D(64, (5, 5), activation='relu', padding='same', data_format="channels_last")(conv1)
#conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
pool1 = MaxPooling2D(pool_size=(2, 2), data_format="channels_last")(conv1)

My problem is that I don't know what to do with tf.contrib.layers.batch_norm.

How can I migrate tf.contrib.layers.batch_norm to TensorFlow 2.x?

UPDATE:
Following the suggestion in the comments, I think I have migrated it correctly:

conv1 = BatchNormalization(momentum=0.99, scale=True, center=True)(conv1)

But I'm not sure whether decay is equivalent to momentum, and I don't know how to set updates_collections in BatchNormalization.
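For what it's worth, here is a minimal sketch of the mapping as I understand it (out_channels and the input shape below are made-up values for illustration): decay corresponds to momentum, and updates_collections=None meant the moving-average updates were applied in place, which is already the default behavior of tf.keras.layers.BatchNormalization in TF2, so it needs no replacement argument. The sketch also keeps the original TF1 ordering conv → batch norm → relu → pool, rather than folding the activation into the conv layer:

```python
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLU, MaxPooling2D

# Assumed values for the demo; substitute your own.
out_channels = 64
inputs = tf.keras.Input(shape=(32, 32, 3))

x = Conv2D(out_channels, (3, 3), padding="same")(inputs)  # no activation here
x = BatchNormalization(momentum=0.99,   # TF1's decay=0.99
                       scale=True, center=True)(x)        # in-place updates are the TF2 default
x = ReLU()(x)
x = MaxPooling2D(pool_size=(2, 2))(x)

model = tf.keras.Model(inputs, x)
print(model.output_shape)  # (None, 16, 16, 64)
```

Passing training=True/False when calling the layer (or using model.fit vs. model.predict) controls whether batch statistics or the moving averages are used, which replaces the old is_training argument.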

I encountered this problem when working with a trained model that I was going to fine-tune. Just replacing tf.contrib.layers.batch_norm with tf.keras.layers.BatchNormalization, as the OP did, gave me an error; the fix is described below.

The old code looked like this:

tf.contrib.layers.batch_norm(
    tensor,
    scale=True,
    center=True,
    is_training=self.use_batch_statistics,
    trainable=True,
    data_format=self._data_format,
    updates_collections=None,
)

and the updated working code looks like this:

tf.keras.layers.BatchNormalization(
    name="BatchNorm",
    scale=True,
    center=True,
    trainable=True,
)(tensor)

I'm unsure whether any of the keyword arguments I removed will cause problems, but everything seems to work. Note the name="BatchNorm" argument: the Keras layers use a different naming scheme, so I had to use the inspect_checkpoint.py tool to look at the model and find the layer name, which happened to be BatchNorm.
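In TF2 you can also list checkpoint variable names directly with tf.train.list_variables instead of the standalone inspect_checkpoint.py script. A self-contained sketch (the layer shape and temporary checkpoint path here are made up for the demo):

```python
import os
import tempfile
import tensorflow as tf

# Build a BatchNormalization layer under the checkpoint-compatible name
# and create its variables (gamma, beta, moving_mean, moving_variance).
bn = tf.keras.layers.BatchNormalization(name="BatchNorm", scale=True, center=True)
bn.build((None, 8))  # assumed feature size for the demo

# Save it so we have a checkpoint to inspect.
ckpt_dir = tempfile.mkdtemp()
path = tf.train.Checkpoint(layer=bn).save(os.path.join(ckpt_dir, "demo"))

# TF2 replacement for inspect_checkpoint.py: print every variable name
# and shape stored in the checkpoint.
for name, shape in tf.train.list_variables(path):
    print(name, shape)
```

Running this against your own checkpoint path shows the stored names, which is how you can discover what to pass as the layer's name argument.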
