How to get the current learning rate of an SGD optimizer in TensorFlow 2.0 when using tf.keras.optimizers.schedules.ExponentialDecay?
I want to reduce the learning rate of the SGD optimizer in TensorFlow 2.0, so I used these lines of code:
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=self.parameter['learning_rate'],
    decay_steps=1000,
    decay_rate=self.parameter['lr_decay']
)
opt = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
But I don't know whether my learning rate has actually dropped. How can I get the current learning rate?
The _decayed_lr method decays the learning_rate based on the number of iterations and returns the actual learning rate at that specific iteration. It also casts the returned value to a dtype that you specify. So the following call does the job for you:

opt._decayed_lr(tf.float32)
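As a concrete sketch (the initial rate, decay settings, and toy loss below are placeholder values, not from the question), you can take a few optimizer steps and then read off the decayed rate. Note that _decayed_lr is an internal method of the legacy (OptimizerV2) optimizers; on newer Keras optimizers it may be absent, in which case you can call the schedule object directly with the optimizer's iteration count:

```python
import tensorflow as tf

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.1,  # placeholder values for illustration
    decay_steps=1000,
    decay_rate=0.96,
)
opt = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)

# Run a few optimization steps so the iteration counter advances.
var = tf.Variable(1.0)
for _ in range(500):
    opt.minimize(lambda: var ** 2, var_list=[var])

# Internal helper from the answer above (legacy/OptimizerV2 optimizers only):
if hasattr(opt, "_decayed_lr"):
    print(opt._decayed_lr(tf.float32).numpy())

# Public alternative: call the schedule with the current step count.
print(lr_schedule(opt.iterations).numpy())
```

Both values should match the schedule's formula, initial_learning_rate * decay_rate ** (step / decay_steps).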
@Lisanu's answer worked for me as well. Here's why and how that answer works:

TensorFlow's GitHub page for tf.keras.optimizers shows the source code. If you scroll down, there is a method named _decayed_lr which lets users get the decayed learning rate as a Tensor with dtype=var_dtype.
Therefore, by calling optimizer._decayed_lr(tf.float32), we can get the current decayed learning rate.
If you'd like to print the current decayed learning rate during training in TensorFlow, you can define a custom callback class that uses optimizer._decayed_lr(tf.float32). For example:
class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        current_decayed_lr = self.model.optimizer._decayed_lr(tf.float32).numpy()
        print("current decayed lr: {:0.7f}".format(current_decayed_lr))