How to implement Batch Normalization on TensorFlow with Keras as a high-level API

BatchNormalization (BN) operates slightly differently in training and in inference. In training, it uses the mean and variance of the current mini-batch to scale its inputs; this means that the exact result of applying batch normalization depends not only on the current input, but also on all the other elements of the mini-batch. This is clearly not desirable in inference mode, where we want a deterministic result. Therefore, in that case, fixed estimates of the mean and variance over the training set are used (in Keras, these are moving averages accumulated during training).
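
The two modes can be sketched in NumPy. This is a hypothetical, simplified re-implementation for illustration only, not the Keras source: the class name ToyBatchNorm is made up, the momentum (0.99) and epsilon (1e-3) values are assumed to match the Keras defaults, and the learnable scale and offset (gamma, beta) are kept fixed at 1 and 0.

```python
import numpy as np


class ToyBatchNorm:
    """Hypothetical, simplified batch normalization over axis 0 (not the Keras source)."""

    def __init__(self, num_features, momentum=0.99, epsilon=1e-3):
        self.momentum = momentum
        self.epsilon = epsilon
        self.gamma = np.ones(num_features)   # learnable scale, kept fixed here
        self.beta = np.zeros(num_features)   # learnable offset, kept fixed here
        self.moving_mean = np.zeros(num_features)
        self.moving_var = np.ones(num_features)

    def __call__(self, x, training):
        if training:
            # Statistics of the current mini-batch: every row affects every output
            mean = x.mean(axis=0)
            var = x.var(axis=0)
            # Exponential moving averages, kept for use at inference time
            m = self.momentum
            self.moving_mean = m * self.moving_mean + (1 - m) * mean
            self.moving_var = m * self.moving_var + (1 - m) * var
        else:
            # Fixed statistics: each row is normalized independently of the others
            mean, var = self.moving_mean, self.moving_var
        return self.gamma * (x - mean) / np.sqrt(var + self.epsilon) + self.beta


rng = np.random.default_rng(0)
bn = ToyBatchNorm(1)
# "Training": accumulate moving statistics over many batches drawn from N(5, 2^2)
for _ in range(2000):
    bn(rng.normal(5.0, 2.0, size=(32, 1)), training=True)

# The same sample (7.0) placed in two very different mini-batches
sample = np.array([[7.0]])
batch_a = np.vstack([sample, rng.normal(5.0, 2.0, size=(31, 1))])
batch_b = np.vstack([sample, rng.normal(-5.0, 2.0, size=(31, 1))])

# Inference mode first (it does not touch the moving statistics)
infer_a = bn(batch_a, training=False)[0, 0]
infer_b = bn(batch_b, training=False)[0, 0]
# Training mode: the output for the sample depends on its batch-mates
train_a = bn(batch_a, training=True)[0, 0]
train_b = bn(batch_b, training=True)[0, 0]

print(f"inference: {infer_a:.3f} vs {infer_b:.3f}")  # identical
print(f"training:  {train_a:.3f} vs {train_b:.3f}")  # different
```

In inference mode the sample gets the same normalized value whatever batch it sits in; in training mode its value shifts with the other 31 rows.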

In TensorFlow, this behavior is controlled by a boolean switch, training, that needs to be specified when calling the layer; see https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization. How do I deal with this switch when using the Keras high-level API? Am I correct in assuming that it is handled automatically, depending on whether we are using model.fit(x, ...) or model.predict(x, ...)?
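
Calling the layer directly shows what the switch does. A minimal sketch, assuming TensorFlow 2.x; the expected values rely on a freshly built layer having gamma = 1, beta = 0, moving mean 0, moving variance 1 and the default epsilon = 1e-3.

```python
import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()  # defaults: momentum=0.99, epsilon=1e-3
x = np.array([[1.0], [2.0], [3.0], [4.0]], dtype="float32")

# training=False: uses the moving statistics (initially mean 0, variance 1),
# so the output is just x / sqrt(1 + 1e-3)
y_infer = bn(x, training=False).numpy()

# training=True: uses this batch's own mean (2.5) and variance (1.25),
# and also nudges the moving statistics toward them
y_train = bn(x, training=True).numpy()

print(y_infer.ravel())
print(y_train.ravel())
```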


To test this, I have written the following example. We start with a random distribution, and we want to classify whether the input is positive or negative. However, we also have a test dataset coming from a different distribution, in which the inputs are shifted by 2 (and consequently the labels check whether x > 2).

import numpy as np
from math import ceil
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense, BatchNormalization

# Public alias instead of the private tensorflow.python.data module
Dataset = tf.data.Dataset

np.random.seed(18)
# Training set: x ~ N(0, 1), label 1 when x > 0
xt = np.random.randn(10_000, 1)
yt = (xt > 0).astype(int)
train_data = Dataset.from_tensor_slices((xt, yt)).shuffle(10_000).repeat().batch(32).prefetch(2)

# Validation set: same distribution as the training set
xv = np.random.randn(100, 1)
yv = (xv > 0).astype(int)
valid_data = Dataset.from_tensor_slices((xv, yv)).repeat().batch(32).prefetch(2)

# Test set: inputs shifted by 2, label 1 when x > 2
xs = np.random.randn(100, 1) + 2
ys = (xs > 2).astype(int)
test_data = Dataset.from_tensor_slices((xs, ys)).repeat().batch(32).prefetch(2)

x = Input(shape=(1,))
a = BatchNormalization()(x)
a = Dense(8, activation='sigmoid')(a)
a = BatchNormalization()(a)
y = Dense(1, activation='sigmoid')(a)
model = Model(inputs=x, outputs=y)
model.summary()

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(train_data, epochs=10, steps_per_epoch=ceil(10_000 / 32), validation_data=valid_data,
          validation_steps=ceil(100 / 32))
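# 4 batches of 32 = 128 predictions; the first 100 line up with xs (test_data is not shuffled)
zs = model.predict(test_data, steps=ceil(100 / 32))
# Fraction of test examples classified correctly
print(np.mean((zs[:100] > 0.5).astype(int) == ys))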

Running the code prints the value 0.5, meaning that half of the examples are labeled correctly. This is what I would expect if the system were using the global statistics of the training set to implement BN.
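
The expected accuracy with fixed statistics can be worked out under an idealized assumption that is not part of the question: the trained network effectively predicts label 1 when its normalized input is positive, and the training-set statistics are mean ≈ 0 and standard deviation ≈ 1. A test input x ~ N(2, 1) is then left essentially unchanged by BN, so the network predicts 1 whenever x > 0, while the true label is 1 only when x > 2 (Φ denotes the standard normal CDF):

$$
\begin{aligned}
\text{acc} &= P(x > 0,\ x > 2) + P(x \le 0,\ x \le 2) \\
           &= P(x > 2) + P(x \le 0) \\
           &= 0.5 + \Phi(-2) \approx 0.5 + 0.023 = 0.523
\end{aligned}
$$

This idealized value is close to the 0.5 observed on 100 test points.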

If we change the BN layers to read:

x = Input(shape=(1,))
a = BatchNormalization()(x, training=True)
a = Dense(8, activation='sigmoid')(a)
a = BatchNormalization()(a, training=True)
y = Dense(1, activation='sigmoid')(a)

and run the code again, we get 0.87. Forcing the training state at all times changes the fraction of correct predictions. This is consistent with the idea that model.predict(x, ...) now uses the mini-batch statistics to implement BN, and is therefore able to partially "correct" the mismatch between the source distributions of the training and test data.
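
The same effect can be simulated in NumPy with an idealized stand-in for the network. Assumptions, none of which come from the question: the network is replaced by the rule "predict 1 if the normalized input is positive", the training statistics are taken as exactly mean 0 and variance 1, and 9,600 test samples are drawn instead of 100 to reduce noise.

```python
import numpy as np

rng = np.random.default_rng(18)
batch_size, n_batches = 32, 300

# Shifted test distribution from the question: x ~ N(2, 1), label 1 when x > 2
xs = rng.normal(2.0, 1.0, size=(n_batches, batch_size))
ys = (xs > 2).astype(int)

eps = 1e-3
# Inference mode: fixed training statistics (assumed mean 0, variance 1)
z_fixed = (xs - 0.0) / np.sqrt(1.0 + eps)
# Forced training mode: each batch is normalized with its own mean and variance
mean = xs.mean(axis=1, keepdims=True)
var = xs.var(axis=1, keepdims=True)
z_batch = (xs - mean) / np.sqrt(var + eps)

# Hypothetical idealized classifier learned on N(0, 1) data:
# predict 1 when the normalized input is positive
acc_fixed = np.mean((z_fixed > 0).astype(int) == ys)
acc_batch = np.mean((z_batch > 0).astype(int) == ys)
print(f"fixed statistics: {acc_fixed:.3f}")
print(f"batch statistics: {acc_batch:.3f}")
```

Per-batch statistics re-center the shifted inputs around zero, so the threshold the classifier learned on the training distribution lines up with the shifted labels again. The idealized accuracies are not directly comparable with the 0.87 reported for the real network.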

Is that correct?

If I understand your question correctly, then yes: Keras automatically switches between training and inference behavior depending on whether you call fit or predict / evaluate. The flag is called learning_phase, and it determines the behavior of batch normalization, dropout, and possibly other layers. The current learning phase can be read with keras.backend.learning_phase() and set with keras.backend.set_learning_phase().

https://keras.io/backend/#learning_phase
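
A minimal sketch of the fit vs predict behavior, assuming TensorFlow 2.x; the names bn_layer and model are illustrative. In recent TensorFlow 2.x releases keras.backend.set_learning_phase() is deprecated, and passing training=True/False when calling a layer or model is the documented replacement, which is what the explicit call below does. The expected numbers assume a freshly built layer (moving mean 0, moving variance 1, epsilon 1e-3, momentum 0.99).

```python
import numpy as np
import tensorflow as tf

x = np.array([[1.0], [2.0], [3.0], [4.0]], dtype="float32")

# BatchNormalization called without an explicit `training` argument
bn_layer = tf.keras.layers.BatchNormalization()  # defaults: momentum=0.99, epsilon=1e-3
inp = tf.keras.Input(shape=(1,))
model = tf.keras.Model(inp, bn_layer(inp))
model.compile(optimizer="sgd", loss="mse")

# predict() runs in inference mode: moving statistics (fresh layer: mean 0, variance 1)
y_predict = model.predict(x, verbose=0)

# Requesting training mode explicitly switches to batch statistics (mean 2.5, variance 1.25);
# this eager call also updates the moving mean: 0.99 * 0 + 0.01 * 2.5 = 0.025
y_train_mode = model(x, training=True).numpy()

# fit() runs in training mode: one step on a single batch of 4 updates it again,
# 0.99 * 0.025 + 0.01 * 2.5 = 0.04975
model.fit(x, x, epochs=1, batch_size=4, verbose=0)
moving_mean = float(bn_layer.moving_mean.numpy()[0])

print(y_predict.ravel())
print(y_train_mode.ravel())
print(moving_mean)
```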
