
How to implement gradient ascent in a Keras DQN

I have built a reinforcement learning DQN that takes variable-length sequences as inputs, with positive and negative rewards calculated for actions. Some problem with my DQN model in Keras means that, although the model runs, average rewards decrease over time, over both single and multiple cycles of epsilon. This does not change even after a significant period of training.

[Figure: average reward decreasing over a single epsilon cycle]

[Figure: average reward decreasing over multiple epsilon cycles]

My thinking is that this is due to using mean squared error as the loss function in Keras (which minimises error). So I am trying to implement gradient ascent (to maximise reward). How can I do this in Keras? My current model is:

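import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Masking, LSTM, Dense
from tensorflow.keras.optimizers import Adam

# env, LEARNING_RATE and DECAY are defined elsewhere in the project
inp = (env.NUM_TIMEPERIODS, env.NUM_FEATURES)  # shape tuple (integers), not including batch size

model = Sequential()
model.add(Input(shape=inp))
model.add(Masking(mask_value=0.))  # ignore zero-padded time steps
model.add(LSTM(env.NUM_FEATURES, return_sequences=True))
model.add(LSTM(env.NUM_FEATURES))
model.add(Dense(env.NUM_FEATURES))
model.add(Dense(4))  # linear output: one Q-value per action

# lr / (1 + DECAY * step), the schedule the old `decay=` argument applied
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    LEARNING_RATE, decay_steps=1, decay_rate=DECAY)

model.compile(loss='mse',
              optimizer=Adam(learning_rate=lr_schedule),
              metrics=[tf.keras.metrics.MeanSquaredError()])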

In trying to implement gradient ascent by 'flipping' the gradient (as a negative or inverse loss?), I have tried various loss definitions:

loss=-'mse'    
loss=-tf.keras.losses.MeanSquaredError()    
loss=1/tf.keras.losses.MeanSquaredError()

but these all fail with TypeErrors about bad operand types (for example, "bad operand type for unary -").
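These errors come from Python itself rather than Keras: -'mse' applies unary minus to a string, and -tf.keras.losses.MeanSquaredError() applies it to a loss object, neither of which supports it (1/... likewise has no meaning for a loss object), so the expression fails before compile is ever called. What can be negated is the value a loss returns, which means wrapping the loss in a new function. A minimal sketch with plain-Python stand-ins; mse and negate are hypothetical helpers, not Keras APIs:

```python
def mse(y_true, y_pred):
    """Mean squared error between two equal-length sequences."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def negate(loss_fn):
    """Wrap loss_fn so the wrapper returns the negated loss value."""
    def negated(y_true, y_pred):
        return -loss_fn(y_true, y_pred)
    return negated


# Unary minus on a string (as in loss=-'mse') fails immediately.
try:
    -'mse'
except TypeError as err:
    print(err)  # bad operand type for unary -: 'str'

# Unary minus on a function object fails the same way.
try:
    -mse
except TypeError as err:
    print(err)  # bad operand type for unary -: 'function'

neg_mse = negate(mse)
print(mse([1.0, 2.0], [0.0, 0.0]))      # 2.5
print(neg_mse([1.0, 2.0], [0.0, 0.0]))  # -2.5
```

In Keras the same idea applies: a custom loss function calls the built-in loss and negates the tensor it returns.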

How can I adapt my current Keras model to maximise rewards? Or is gradient ascent not even the problem? Could it be an issue with the action policy?

Writing a custom loss function

Here is the loss function you want:

import tensorflow as tf

@tf.function
def positive_mse(y_true, y_pred):
    # Negated MSE: minimising this value maximises the squared error
    return -1 * tf.keras.losses.MSE(y_true, y_pred)
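To see what the negation does to training, here is a minimal plain-Python sketch (not Keras code) in which an optimizer that only knows how to descend is run on sign * MSE. With sign=+1 it is ordinary gradient descent; with sign=-1 it is gradient ascent on MSE, which is what compiling with a negated loss amounts to. The targets, learning rate and step count are arbitrary illustrative values, and mse, mse_grad and train are hypothetical helpers.

```python
def mse(y_true, y_pred):
    """Mean squared error between two equal-length lists."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def mse_grad(y_true, y_pred):
    """Gradient of mse with respect to each prediction: -2 * (t - p) / n."""
    n = len(y_true)
    return [-2.0 * (t - p) / n for t, p in zip(y_true, y_pred)]


def train(sign, steps=20, lr=0.1):
    """Run gradient descent on sign * MSE and return the final (unsigned) MSE.

    sign=+1 minimises MSE; sign=-1 minimises -MSE, i.e. gradient ascent on MSE.
    """
    y_true = [1.0, -2.0, 3.0]
    y_pred = [0.0, 0.0, 0.0]
    for _ in range(steps):
        grads = mse_grad(y_true, y_pred)
        y_pred = [p - lr * sign * g for p, g in zip(y_pred, grads)]
    return mse(y_true, y_pred)


start = mse([1.0, -2.0, 3.0], [0.0, 0.0, 0.0])
print(start)      # about 4.667 (14 / 3)
print(train(+1))  # smaller than start: predictions move towards y_true
print(train(-1))  # larger than start: predictions move away from y_true
```

Because y_true holds the values the model is meant to predict, the ascent run ends with predictions further from those values than where it started.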

Your compile line then becomes:

model.compile(loss=positive_mse,
              optimizer=Adam(learning_rate=tf.keras.optimizers.schedules.InverseTimeDecay(
                  LEARNING_RATE, decay_steps=1, decay_rate=DECAY)),
              metrics=[tf.keras.metrics.MeanSquaredError()])

Please note: use loss=positive_mse, not loss=positive_mse(). That's not a typo: you need to pass the function itself, not the result of calling it.
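On the question of whether gradient ascent is the problem at all: in the standard DQN formulation, reward maximisation enters through the regression targets rather than through the sign of the loss. For each sampled transition, the target for the action taken is r for a terminal step, or r + gamma * max(Q(s', a')) otherwise, and the network is trained to minimise the MSE towards those targets. The sketch below builds such targets in plain Python; dqn_targets, its arguments and gamma are hypothetical names, and in a Keras setup q_pred and q_next would typically come from model.predict on the current and next states (possibly from a separate target network).

```python
def dqn_targets(q_pred, actions, rewards, q_next, dones, gamma=0.99):
    """Build Q-learning regression targets for a batch of transitions.

    q_pred:  Q-values predicted for the current states, one list per sample
    actions: index of the action taken in each sample
    rewards: reward received for each sample
    q_next:  Q-values predicted for the next states, one list per sample
    dones:   True where the episode ended after the action
    """
    # Start from a copy of the current predictions so actions not taken contribute zero error.
    targets = [list(row) for row in q_pred]
    for i, action in enumerate(actions):
        if dones[i]:
            targets[i][action] = rewards[i]
        else:
            # Bellman target: immediate reward plus discounted best next-state value.
            targets[i][action] = rewards[i] + gamma * max(q_next[i])
    return targets


q_pred = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.5, 0.5, 0.5]]
q_next = [[1.0, 2.0, 0.0, 0.5], [0.0, 0.0, 0.0, 0.0]]
targets = dqn_targets(q_pred, actions=[2, 0], rewards=[1.0, -1.0],
                      q_next=q_next, dones=[False, True], gamma=0.5)
print(targets)  # [[0.1, 0.2, 2.0, 0.4], [-1.0, 0.5, 0.5, 0.5]]
```

Targets built this way would then be passed to model.fit(states, targets) with the ordinary 'mse' loss, where states is a hypothetical batch of input sequences.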
