![](/img/trans.png)
[英]what is the Tensorflow batch size when you use high-level API tf.contrib.learn.DNNClassifier
[英]How to implement Batch Normalization on tensorflow with Keras as a high-level API
在訓練和推論中,BatchNormalization(BN)的操作略有不同。 在訓練中,它使用當前迷你批次的平均值和方差來縮放其輸入; 這意味着應用批量標准化的確切結果不僅取決於當前輸入,還取決於微型批量的所有其他元素。 當在推理模式下我們想要確定的結果時,這顯然是不希望的。 因此,在這種情況下,將使用整個訓練集的全局平均值和方差的固定統計量。
在Tensorflow中,此行為由布爾開關training
控制,在調用圖層時需要指定布爾開關training
,請參閱https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization 。 使用Keras高級API時如何處理此開關? 我是否假設根據我們使用的是model.fit(x, ...)
還是model.predict(x, ...)
是自動處理的方法正確嗎?
為了測試這一點,我編寫了這個示例。 我們從隨機分布開始,我們想對輸入是正還是負進行分類。 但是,我們還有一個來自不同分布的測試數據集,其中輸入偏移了2(因此標簽檢查x是否大於2)。
import numpy as np
from math import ceil
from tensorflow.python.data import Dataset
from tensorflow.python.keras import Input, Model
from tensorflow.python.keras.layers import Dense, BatchNormalization
np.random.seed(18)
xt = np.random.randn(10_000, 1)
yt = np.array([[int(x > 0)] for x in xt])
train_data = Dataset.from_tensor_slices((xt, yt)).shuffle(10_000).repeat().batch(32).prefetch(2)
xv = np.random.randn(100, 1)
yv = np.array([[int(x > 0)] for x in xv])
valid_data = Dataset.from_tensor_slices((xv, yv)).repeat().batch(32).prefetch(2)
xs = np.random.randn(100, 1) + 2
ys = np.array([[int(x > 2)] for x in xs])
test_data = Dataset.from_tensor_slices((xs, ys)).repeat().batch(32).prefetch(2)
x = Input(shape=(1,))
a = BatchNormalization()(x)
a = Dense(8, activation='sigmoid')(a)
a = BatchNormalization()(a)
y = Dense(1, activation='sigmoid')(a)
model = Model(inputs=x, outputs=y, )
model.summary()
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(train_data, epochs=10, steps_per_epoch=ceil(10_000 / 32), validation_data=valid_data,
validation_steps=ceil(100 / 32))
zs = model.predict(test_data, steps=ceil(100 / 32))
print(sum([ys[i] == int(zs[i] > 0.5) for i in range(100)]))
運行代碼將輸出值0.5,這意味着一半示例已正確標記。 如果系統使用培訓集上的全局統計數據來實施BN,這就是我所期望的。
如果我們將BN層更改為讀取
x = Input(shape=(1,))
a = BatchNormalization()(x, training=True)
a = Dense(8, activation='sigmoid')(a)
a = BatchNormalization()(a, training=True)
y = Dense(1, activation='sigmoid')(a)
然后再次運行代碼,我們發現0.87。 始終強制訓練狀態,正確預測的百分比已更改。 這與model.predict(x, ...)
現在使用小批量的統計信息來實現BN的想法是一致的,因此能夠稍微“校正”訓練和測試之間的源分布不匹配。數據。
那是對的嗎?
如果我正確地理解了您的問題,那么是的,keras會根據fit
與predict
/ evaluate
自動管理訓練與推理行為。 該標志稱為learning_phase
,它確定批處理規范,退出和可能的其他情況的行為。 當前的學習階段可以通過keras.backend.learning_phase()
,並通過keras.backend.set_learning_phase()
設置。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.