[英]what is the Tensorflow batch size when you use high-level API tf.contrib.learn.DNNClassifier
[英]How to implement Batch Normalization on tensorflow with Keras as a high-level API
在训练和推论中,BatchNormalization(BN)的操作略有不同。 在训练中,它使用当前迷你批次的平均值和方差来缩放其输入; 这意味着应用批量标准化的确切结果不仅取决于当前输入,还取决于微型批量的所有其他元素。 当在推理模式下我们想要确定的结果时,这显然是不希望的。 因此,在这种情况下,将使用整个训练集的全局平均值和方差的固定统计量。
在Tensorflow中,此行为由布尔开关training
控制,在调用图层时需要指定布尔开关training
,请参阅https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization 。 使用Keras高级API时如何处理此开关? 我是否假设根据我们使用的是model.fit(x, ...)
还是model.predict(x, ...)
是自动处理的方法正确吗?
为了测试这一点,我编写了这个示例。 我们从随机分布开始,我们想对输入是正还是负进行分类。 但是,我们还有一个来自不同分布的测试数据集,其中输入偏移了2(因此标签检查x是否大于2)。
import numpy as np
from math import ceil
from tensorflow.python.data import Dataset
from tensorflow.python.keras import Input, Model
from tensorflow.python.keras.layers import Dense, BatchNormalization
np.random.seed(18)
xt = np.random.randn(10_000, 1)
yt = np.array([[int(x > 0)] for x in xt])
train_data = Dataset.from_tensor_slices((xt, yt)).shuffle(10_000).repeat().batch(32).prefetch(2)
xv = np.random.randn(100, 1)
yv = np.array([[int(x > 0)] for x in xv])
valid_data = Dataset.from_tensor_slices((xv, yv)).repeat().batch(32).prefetch(2)
xs = np.random.randn(100, 1) + 2
ys = np.array([[int(x > 2)] for x in xs])
test_data = Dataset.from_tensor_slices((xs, ys)).repeat().batch(32).prefetch(2)
x = Input(shape=(1,))
a = BatchNormalization()(x)
a = Dense(8, activation='sigmoid')(a)
a = BatchNormalization()(a)
y = Dense(1, activation='sigmoid')(a)
model = Model(inputs=x, outputs=y, )
model.summary()
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(train_data, epochs=10, steps_per_epoch=ceil(10_000 / 32), validation_data=valid_data,
validation_steps=ceil(100 / 32))
zs = model.predict(test_data, steps=ceil(100 / 32))
print(sum([ys[i] == int(zs[i] > 0.5) for i in range(100)]))
运行代码将输出值0.5,这意味着一半示例已正确标记。 如果系统使用培训集上的全局统计数据来实施BN,这就是我所期望的。
如果我们将BN层更改为读取
x = Input(shape=(1,))
a = BatchNormalization()(x, training=True)
a = Dense(8, activation='sigmoid')(a)
a = BatchNormalization()(a, training=True)
y = Dense(1, activation='sigmoid')(a)
然后再次运行代码,我们发现0.87。 始终强制训练状态,正确预测的百分比已更改。 这与model.predict(x, ...)
现在使用小批量的统计信息来实现BN的想法是一致的,因此能够稍微“校正”训练和测试之间的源分布不匹配。数据。
那是对的吗?
如果我正确地理解了您的问题,那么是的,keras会根据fit
与predict
/ evaluate
自动管理训练与推理行为。 该标志称为learning_phase
,它确定批处理规范,退出和可能的其他情况的行为。 当前的学习阶段可以通过keras.backend.learning_phase()
,并通过keras.backend.set_learning_phase()
设置。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.