
Learning rate and weight decay schedule in Tensorflow SGDW optimizer

I'm trying to reproduce part of this paper with TensorFlow. The problem is that the authors use SGD with weight decay, cutting the learning rate to 1/10 every 30 epochs.

TensorFlow documentation says that

when applying a decay to the learning rate, be sure to manually apply the decay to the weight_decay as well

So I tried with

schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.003,
    decay_rate=0.1,
    decay_steps=steps_per_epoch * 30,
    staircase=True,
)
optimizer = tfa.optimizers.SGDW(
    learning_rate=schedule,
    weight_decay=schedule,
    momentum=0.9,
)

(steps_per_epoch previously initialized)
as I would with Keras SGD. However, this doesn't work and raises a `TypeError: Expected float32` for the `weight_decay` parameter. What's the correct way to achieve the target behaviour?
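For reference, with `staircase=True` the schedule above computes `initial_lr * decay_rate ** floor(step / decay_steps)`. A pure-Python sketch of that arithmetic (constants taken from the snippet; `steps_per_epoch = 100` is an assumed value for illustration):

```python
import math

steps_per_epoch = 100  # assumed value for illustration

def staircase_lr(step, initial_lr=0.003, decay_rate=0.1,
                 decay_steps=steps_per_epoch * 30):
    # staircase=True makes the exponent an integer, so the rate drops
    # in discrete 10x jumps instead of decaying continuously
    return initial_lr * decay_rate ** math.floor(step / decay_steps)

print(staircase_lr(0))                     # start of training
print(staircase_lr(30 * steps_per_epoch))  # after 30 epochs: 10x smaller
```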

You are getting the error because you are passing a Keras `ExponentialDecay` schedule to the `weight_decay` argument of the TensorFlow Addons `SGDW` optimizer, which expects a float there.

As per the paper, the hyper-parameters are:

  1. weight decay of 0.001
  2. momentum of 0.9
  3. starting learning rate of 0.003, which is reduced by a factor of 10 after 30 epochs

So why not use `LearningRateScheduler` to reduce it by a factor of 10 every 30 epochs?

Sample Code

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, input_shape=(10,)),
    tf.keras.layers.Dense(4, activation='relu'),
    tf.keras.layers.Dense(3, activation='softmax'),
])

# Dummy data: 10 samples, 10 features, 3 one-hot encoded classes
X = np.random.randn(10, 10)
y = tf.keras.utils.to_categorical(np.random.randint(0, 3, 10), num_classes=3)

model.compile(
    optimizer=tfa.optimizers.SGDW(
        weight_decay=0.001,
        momentum=0.9,
        learning_rate=0.003),
      loss=tf.keras.losses.categorical_crossentropy)

def scheduler(epoch, lr):
  # cut the learning rate by 10x every 30 epochs (skip epoch 0,
  # otherwise the rate is reduced before training starts)
  if epoch % 30 == 0 and epoch != 0:
    lr = lr * 0.1
  return lr

callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
model.fit(X, y, callbacks=[callback], epochs=100)
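One caveat: `LearningRateScheduler` only touches the learning rate, while the documentation quoted in the question says the weight decay should be cut on the same schedule. A minimal arithmetic sketch of the values both parameters would take under the paper's schedule (the function name is hypothetical):

```python
def step_decayed(initial, epoch, factor=0.1, period=30):
    # value of a hyper-parameter after cutting `initial` by `factor`
    # once every `period` epochs
    return initial * factor ** (epoch // period)

# learning rate and weight decay follow the same multiplicative schedule
for epoch in (0, 29, 30, 60):
    lr = step_decayed(0.003, epoch)  # paper's starting learning rate
    wd = step_decayed(0.001, epoch)  # paper's weight decay
    print(epoch, lr, wd)
```

To apply this inside training you would also need to update the optimizer's `weight_decay` variable each epoch (for example from a custom callback); whether your TFA version accepts a callable for `weight_decay` directly is something to verify against its documentation.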
