繁体   English   中英

如何使用Keras作为高级API在Tensorflow上实现批量规范化

[英]How to implement Batch Normalization on tensorflow with Keras as a high-level API

在训练和推论中,BatchNormalization(BN)的操作略有不同。 在训练中,它使用当前迷你批次的平均值和方差来缩放其输入; 这意味着应用批量标准化的确切结果不仅取决于当前输入,还取决于微型批量的所有其他元素。 当在推理模式下我们想要确定的结果时,这显然是不希望的。 因此,在这种情况下,将使用整个训练集的全局平均值和方差的固定统计量。

在Tensorflow中,此行为由布尔开关training控制,在调用图层时需要指定布尔开关training ,请参阅https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization 使用Keras高级API时如何处理此开关? 我是否假设根据我们使用的是model.fit(x, ...)还是model.predict(x, ...)是自动处理的方法正确吗?


为了测试这一点,我编写了这个示例。 我们从随机分布开始,我们想对输入是正还是负进行分类。 但是,我们还有一个来自不同分布的测试数据集,其中输入偏移了2(因此标签检查x是否大于2)。

import numpy as np
from math import ceil
from tensorflow.python.data import Dataset
from tensorflow.python.keras import Input, Model
from tensorflow.python.keras.layers import Dense, BatchNormalization

np.random.seed(18)
xt = np.random.randn(10_000, 1)
yt = np.array([[int(x > 0)] for x in xt])
train_data = Dataset.from_tensor_slices((xt, yt)).shuffle(10_000).repeat().batch(32).prefetch(2)

xv = np.random.randn(100, 1)
yv = np.array([[int(x > 0)] for x in xv])
valid_data = Dataset.from_tensor_slices((xv, yv)).repeat().batch(32).prefetch(2)

xs = np.random.randn(100, 1) + 2
ys = np.array([[int(x > 2)] for x in xs])
test_data = Dataset.from_tensor_slices((xs, ys)).repeat().batch(32).prefetch(2)

x = Input(shape=(1,))
a = BatchNormalization()(x)
a = Dense(8, activation='sigmoid')(a)
a = BatchNormalization()(a)
y = Dense(1, activation='sigmoid')(a)
model = Model(inputs=x, outputs=y, )
model.summary()

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(train_data, epochs=10, steps_per_epoch=ceil(10_000 / 32), validation_data=valid_data,
          validation_steps=ceil(100 / 32))
zs = model.predict(test_data, steps=ceil(100 / 32))
print(sum([ys[i] == int(zs[i] > 0.5) for i in range(100)]))

运行代码将输出值0.5,这意味着一半示例已正确标记。 如果系统使用培训集上的全局统计数据来实施BN,这就是我所期望的。

如果我们将BN层更改为读取

x = Input(shape=(1,))
a = BatchNormalization()(x, training=True)
a = Dense(8, activation='sigmoid')(a)
a = BatchNormalization()(a, training=True)
y = Dense(1, activation='sigmoid')(a)

然后再次运行代码,我们发现0.87。 始终强制训练状态,正确预测的百分比已更改。 这与model.predict(x, ...)现在使用小批量的统计信息来实现BN的想法是一致的,因此能够稍微“校正”训练和测试之间的源分布不匹配。数据。

那是对的吗?

如果我正确地理解了您的问题,那么是的,keras会根据fitpredict / evaluate自动管理训练与推理行为。 该标志称为learning_phase ,它确定批处理规范,退出和可能的其他情况的行为。 当前的学习阶段可以通过keras.backend.learning_phase() ,并通过keras.backend.set_learning_phase()设置。

https://keras.io/backend/#learning_phase

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM