keras: how to use learning rate decay with model.train_on_batch()

In my current project I'm using Keras' train_on_batch() function for training, since fit() does not support the alternating training of generator and discriminator required for GANs. Using (for example) the Adam optimizer, I have to specify the learning rate decay in the constructor with optimizer = Adam(decay=my_decay) and hand this optimizer to the model's compile method. This works fine if I use the model's fit() function afterwards, since fit() counts the training iterations internally, but I don't know how I can set this value myself using a construct like

counter = 0
for epoch in range(EPOCHS):
    for batch_idx in range(0, number_training_samples, BATCH_SIZE):
        # get training batch:
        x = ...
        y = ...
        # calculate learning rate:
        current_learning_rate = calculate_learning_rate(counter)
        # train model:
        loss = model.train_on_batch(x, y)    # how to use the current learning rate?
        counter += 1

with some function to calculate the learning rate. How can I set the current learning rate manually?
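For illustration, calculate_learning_rate could be something like the time-based decay that Keras' built-in optimizers apply when constructed with decay=my_decay; this is only a hypothetical sketch of such a function, not code I'm actually committed to:

```python
# Hypothetical example of calculate_learning_rate: reproduces the
# time-based decay lr_t = lr_0 / (1 + decay * t) that Keras' legacy
# optimizers apply internally when decay > 0 is passed.
def calculate_learning_rate(iteration, initial_lr=1e-3, decay=1e-4):
    return initial_lr / (1.0 + decay * iteration)
```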

If there are mistakes in this post, I'm sorry; it's my first question here.

Thank you already for any help.

EDIT

In Keras 2.3.0, lr was renamed to learning_rate: link. In older versions you should use lr instead (thanks @Bananach).

Set the value with the help of the Keras backend: keras.backend.set_value(model.optimizer.learning_rate, learning_rate) (where learning_rate is a float, the desired learning rate). This works for the fit method and should work for train_on_batch as well:

from keras import backend as K

counter = 0
for epoch in range(EPOCHS):
    for batch_idx in range(0, number_training_samples, BATCH_SIZE):
        # get training batch:
        x = ...
        y = ...
        # calculate learning rate:
        current_learning_rate = calculate_learning_rate(counter)
        # set the new learning rate on the optimizer:
        K.set_value(model.optimizer.learning_rate, current_learning_rate)
        # train model:
        loss = model.train_on_batch(x, y)
        counter += 1
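If the same code has to run on Keras versions before and after the 2.3.0 rename, the attribute lookup can be bridged with a tiny helper; this is a sketch that assumes nothing beyond the rename itself:

```python
# Hedged sketch: return the optimizer's learning-rate variable under
# either attribute name, so the same loop works across Keras versions.
def get_lr_variable(optimizer):
    if hasattr(optimizer, "learning_rate"):
        return optimizer.learning_rate  # Keras >= 2.3.0
    return optimizer.lr                 # older Keras

# usage inside the training loop:
# K.set_value(get_lr_variable(model.optimizer), current_learning_rate)
```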

Hope it helps!
