繁体   English   中英

keras:如何在 model.train_on_batch() 中使用学习率衰减

[英]keras: how to use learning rate decay with model.train_on_batch()

在我当前的项目中,我使用 keras 的train_on_batch()函数进行训练,因为fit()函数不支持 GAN 所需的生成器和鉴别器的交替训练。 使用(例如)Adam 优化器,我必须在构造函数中指定optimizer = Adam(decay=my_decay)的学习率衰减,并将其交给模型编译方法。 如果我之后使用模型的fit()函数,这个工作很好,因为它负责在内部计算训练重复,但我不知道如何使用类似的构造自己设置这个值

counter = 0
for epoch in range(EPOCHS):
    for batch_idx in range(0, number_training_samples, BATCH_SIZE):
        # get training batch:
        x = ...
        y = ...
        # calculate learning rate:
        current_learning_rate = calculate_learning_rate(counter)
        # train model:
        loss = model.train_on_batch(x, y)    # how to use the current learning rate?

用一些函数来计算学习率。 如何手动设置当前学习率?

如果这篇文章中有错误,我很抱歉,这是我在这里的第一个问题。

已经感谢您提供任何帮助。

编辑

在 2.3.0 中, lr被重命名为learning_rate : link 在旧版本中,您应该改用lr (感谢@Bananach)。

在 keras 后端的帮助下设置值: keras.backend.set_value(model.optimizer.learning_rate, learning_rate) (其中learning_rate是浮点数,所需的学习率)适用于fit方法,应该适用于 train_on_batch:

from keras import backend as K


counter = 0
for epoch in range(EPOCHS):
    for batch_idx in range(0, number_training_samples, BATCH_SIZE):
        # get training batch:
        x = ...
        y = ...
        # calculate learning rate:
        current_learning_rate = calculate_learning_rate(counter)
        # train model:
        K.set_value(model.optimizer.learning_rate, current_learning_rate)  # set new learning_rate
        loss = model.train_on_batch(x, y) 

希望能帮助到你!

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM