![](/img/trans.png)
[英]Rewrite tf.contrib.layers.batch_norm in Tensorflow 2.0
[英]Migrate tf.contrib.layers.batch_norm to Tensorflow 2.0
我正在將 TensorFlow 代碼遷移到 Tensorflow 2.1.0。
這是原始代碼:
conv = tf.layers.conv2d(inputs, out_channels, kernel_size=3, padding='SAME')
conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
conv = tf.nn.relu(conv)
conv = tf.contrib.layers.max_pool2d(conv, 2)
這就是我所做的:
conv1 = Conv2D(out_channels, (3, 3), activation='relu', padding='same', data_format='channels_last', name=name)(inputs)
conv1 = Conv2D(64, (5, 5), activation='relu', padding='same', data_format="channels_last")(conv1)
#conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
pool1 = MaxPooling2D(pool_size=(2, 2), data_format="channels_last")(conv1)
我的問題是我不知道如何處理tf.contrib.layers.batch_norm
。
如何將tf.contrib.layers.batch_norm
遷移到 Tensorflow 2.x?
更新:
使用評論建議,我認為我已經正確遷移:
conv1 = BatchNormalization(momentum=0.99, scale=True, center=True)(conv1)
但我不確定decay
是否像momentum
,我不知道如何在BatchNormalization
方法中設置updates_collections
。
我在使用訓練有素的 model 時遇到了這個問題,我將對其進行微調。 只是像 OP 一樣用tf.keras.layers.BatchNormalization
替換tf.contrib.layers.batch_norm
確實給了我一個錯誤,其修復方法如下所述。
舊代碼如下所示:
tf.contrib.layers.batch_norm(
tensor,
scale=True,
center=True,
is_training=self.use_batch_statistics,
trainable=True,
data_format=self._data_format,
updates_collections=None,
)
更新后的工作代碼如下所示:
tf.keras.layers.BatchNormalization(
name="BatchNorm",
scale=True,
center=True,
trainable=True,
)(tensor)
我不確定我刪除的所有關鍵字 arguments 是否都會成為問題,但似乎一切正常。 請注意name="BatchNorm"
參數。 這些圖層使用不同的命名模式,因此我不得不使用inspect_checkpoint.py
工具查看 model 並找到恰好是BatchNorm
的圖層名稱。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.