簡體   English   中英

將 tf.contrib.layers.batch_norm 遷移到 Tensorflow 2.0

[英]Migrate tf.contrib.layers.batch_norm to Tensorflow 2.0

我正在將 TensorFlow 代碼遷移到 Tensorflow 2.1.0。

這是原始代碼:

conv = tf.layers.conv2d(inputs, out_channels, kernel_size=3, padding='SAME')
conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
conv = tf.nn.relu(conv)
conv = tf.contrib.layers.max_pool2d(conv, 2)

這就是我所做的:

conv1 = Conv2D(out_channels, (3, 3), activation='relu', padding='same', data_format='channels_last', name=name)(inputs)
conv1 = Conv2D(64, (5, 5), activation='relu', padding='same', data_format="channels_last")(conv1)
#conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
pool1 = MaxPooling2D(pool_size=(2, 2), data_format="channels_last")(conv1)

我的問題是我不知道如何處理tf.contrib.layers.batch_norm

如何將tf.contrib.layers.batch_norm遷移到 Tensorflow 2.x?

更新:
使用評論建議,我認為我已經正確遷移:

conv1 = BatchNormalization(momentum=0.99, scale=True, center=True)(conv1)

但我不確定decay是否像momentum ,我不知道如何在BatchNormalization方法中設置updates_collections

我在使用訓練有素的 model 時遇到了這個問題,我將對其進行微調。 只是像 OP 一樣用tf.keras.layers.BatchNormalization替換tf.contrib.layers.batch_norm確實給了我一個錯誤,其修復方法如下所述。

舊代碼如下所示:

tf.contrib.layers.batch_norm(
    tensor,
    scale=True,
    center=True,
    is_training=self.use_batch_statistics,
    trainable=True,
    data_format=self._data_format,
    updates_collections=None,
)

更新后的工作代碼如下所示:

tf.keras.layers.BatchNormalization(
    name="BatchNorm",
    scale=True,
    center=True,
    trainable=True,
)(tensor)

我不確定我刪除的所有關鍵字 arguments 是否都會成為問題,但似乎一切正常。 請注意name="BatchNorm"參數。 這些圖層使用不同的命名模式,因此我不得不使用inspect_checkpoint.py工具查看 model 並找到恰好是BatchNorm的圖層名稱。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM