簡體   English   中英

如何使用Keras作為高級API在Tensorflow上實現批量規范化

[英]How to implement Batch Normalization on tensorflow with Keras as a high-level API

在訓練和推論中,BatchNormalization(BN)的操作略有不同。 在訓練中,它使用當前迷你批次的平均值和方差來縮放其輸入; 這意味着應用批量標准化的確切結果不僅取決於當前輸入,還取決於微型批量的所有其他元素。 當在推理模式下我們想要確定的結果時,這顯然是不希望的。 因此,在這種情況下,將使用整個訓練集的全局平均值和方差的固定統計量。

在Tensorflow中,此行為由布爾開關training控制,在調用圖層時需要指定布爾開關training ,請參閱https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization 使用Keras高級API時如何處理此開關? 我是否假設根據我們使用的是model.fit(x, ...)還是model.predict(x, ...)是自動處理的方法正確嗎?


為了測試這一點,我編寫了這個示例。 我們從隨機分布開始,我們想對輸入是正還是負進行分類。 但是,我們還有一個來自不同分布的測試數據集,其中輸入偏移了2(因此標簽檢查x是否大於2)。

import numpy as np
from math import ceil
from tensorflow.python.data import Dataset
from tensorflow.python.keras import Input, Model
from tensorflow.python.keras.layers import Dense, BatchNormalization

np.random.seed(18)
xt = np.random.randn(10_000, 1)
yt = np.array([[int(x > 0)] for x in xt])
train_data = Dataset.from_tensor_slices((xt, yt)).shuffle(10_000).repeat().batch(32).prefetch(2)

xv = np.random.randn(100, 1)
yv = np.array([[int(x > 0)] for x in xv])
valid_data = Dataset.from_tensor_slices((xv, yv)).repeat().batch(32).prefetch(2)

xs = np.random.randn(100, 1) + 2
ys = np.array([[int(x > 2)] for x in xs])
test_data = Dataset.from_tensor_slices((xs, ys)).repeat().batch(32).prefetch(2)

x = Input(shape=(1,))
a = BatchNormalization()(x)
a = Dense(8, activation='sigmoid')(a)
a = BatchNormalization()(a)
y = Dense(1, activation='sigmoid')(a)
model = Model(inputs=x, outputs=y, )
model.summary()

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(train_data, epochs=10, steps_per_epoch=ceil(10_000 / 32), validation_data=valid_data,
          validation_steps=ceil(100 / 32))
zs = model.predict(test_data, steps=ceil(100 / 32))
print(sum([ys[i] == int(zs[i] > 0.5) for i in range(100)]))

運行代碼將輸出值0.5,這意味着一半示例已正確標記。 如果系統使用培訓集上的全局統計數據來實施BN,這就是我所期望的。

如果我們將BN層更改為讀取

x = Input(shape=(1,))
a = BatchNormalization()(x, training=True)
a = Dense(8, activation='sigmoid')(a)
a = BatchNormalization()(a, training=True)
y = Dense(1, activation='sigmoid')(a)

然后再次運行代碼,我們發現0.87。 始終強制訓練狀態,正確預測的百分比已更改。 這與model.predict(x, ...)現在使用小批量的統計信息來實現BN的想法是一致的,因此能夠稍微“校正”訓練和測試之間的源分布不匹配。數據。

那是對的嗎?

如果我正確地理解了您的問題,那么是的,keras會根據fitpredict / evaluate自動管理訓練與推理行為。 該標志稱為learning_phase ,它確定批處理規范,退出和可能的其他情況的行為。 當前的學習階段可以通過keras.backend.learning_phase() ,並通過keras.backend.set_learning_phase()設置。

https://keras.io/backend/#learning_phase

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM