Learning rate and weight decay schedule in TensorFlow SGDW optimizer
I'm trying to reproduce part of this paper with TensorFlow. The problem is that the authors use SGD with weight decay, cutting the learning rate to 1/10 every 30 epochs.
The TensorFlow documentation says that

when applying a decay to the learning rate, be sure to manually apply the decay to the weight_decay as well
So I tried:
schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.003,
    decay_rate=0.1,
    decay_steps=steps_per_epoch * 30,
    staircase=True)

optimizer = tfa.optimizers.SGDW(
    learning_rate=schedule,
    weight_decay=schedule,
    momentum=0.9)
(steps_per_epoch previously initialized)
As I would with Keras SGD. This, however, doesn't work and raises a "TypeError: Expected float32" for the weight_decay parameter. What's the correct way to achieve the target behaviour?
You are getting the error because you are passing the Keras ExponentialDecay schedule as the weight_decay argument of the TensorFlow Addons optimizer SGDW, which expects a float.
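For what it's worth, the documentation note can also be satisfied directly: newer TensorFlow Addons releases accept a zero-argument callable (and, in recent versions, a schedule) for weight_decay, so the decay can be evaluated at the optimizer's current step. A minimal sketch, assuming such a release; steps_per_epoch and the 0.001 weight-decay coefficient are placeholders:

import tensorflow as tf
import tensorflow_addons as tfa

steps_per_epoch = 100  # placeholder; use the value from your data pipeline

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.003,
    decay_rate=0.1,
    decay_steps=steps_per_epoch * 30,
    staircase=True)

wd_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.001,  # the weight-decay coefficient, not an lr
    decay_rate=0.1,
    decay_steps=steps_per_epoch * 30,
    staircase=True)

# weight_decay is a zero-argument callable evaluated at every step, so it
# follows the same staircase as the learning rate.
optimizer = tfa.optimizers.SGDW(
    learning_rate=lr_schedule,
    weight_decay=lambda: wd_schedule(optimizer.iterations),
    momentum=0.9)

That said, a schedule-free alternative is the LearningRateScheduler callback below.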
As per the paper, the hyper-parameters are momentum 0.9, weight decay 0.001 and an initial learning rate of 0.003, divided by 10 every 30 epochs.
So why not use a LearningRateScheduler callback to reduce the learning rate by a factor of 10 every 30 epochs?
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, input_shape=(10,)),
    tf.keras.layers.Dense(4, activation='relu'),
    tf.keras.layers.Dense(3, activation='softmax'),
])

# Dummy data: 10 samples with 10 features, one-hot targets for 3 classes.
X = np.random.randn(10, 10)
y = tf.keras.utils.to_categorical(np.random.randint(0, 3, 10), num_classes=3)

model.compile(
    optimizer=tfa.optimizers.SGDW(
        weight_decay=0.001,
        momentum=0.9,
        learning_rate=0.003),
    loss=tf.keras.losses.categorical_crossentropy)

# Divide the learning rate by 10 every 30 epochs (skipping epoch 0,
# which would otherwise reduce the initial rate immediately).
def scheduler(epoch, lr):
    if epoch > 0 and epoch % 30 == 0:
        lr = lr * 0.1
    return lr

callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
model.fit(X, y, callbacks=[callback], epochs=100)
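One caveat with this approach: LearningRateScheduler only updates the learning rate, while the documentation note quoted in the question says the decay must be applied to weight_decay as well. A minimal sketch of a custom callback that keeps the two in sync (the class name SyncedDecay is mine, and it assumes weight_decay was passed to SGDW as a plain float, so the optimizer exposes it as a mutable hyper-parameter):

import tensorflow as tf

class SyncedDecay(tf.keras.callbacks.Callback):
    # Divides both learning_rate and weight_decay by 10 every 30 epochs.
    def on_epoch_begin(self, epoch, logs=None):
        if epoch > 0 and epoch % 30 == 0:
            opt = self.model.optimizer
            # Both hyper-parameters are tf.Variables once training has
            # started, so they can be updated in place.
            tf.keras.backend.set_value(
                opt.learning_rate,
                tf.keras.backend.get_value(opt.learning_rate) * 0.1)
            tf.keras.backend.set_value(
                opt.weight_decay,
                tf.keras.backend.get_value(opt.weight_decay) * 0.1)

Used in place of the LearningRateScheduler callback above:

model.fit(X, y, callbacks=[SyncedDecay()], epochs=100)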