I'm trying to build a Double DQN network for CartPole-v0, but it isn't working as expected: the reward stagnates at around 8-9. What am I doing wrong?
Each step in the learning phase:
```python
def make_step(model, target_model, optimizer, criterion, observation, action, reward, next_observation):
    inp_obv = torch.Tensor(observation)
    q = model(inp_obv)
    q_argmax = torch.argmax(q.data)
    q = q[action]
    inp_next_obv = torch.Tensor(next_observation)
    q_next = target_model(inp_next_obv)
    q_a_next = q_next[q_argmax]
    # LHS of the double DQN equation
    obv_reward = q
    # RHS of the double DQN equation
    target_reward = torch.Tensor([reward]) + GAMMA * q_a_next.detach()
    # Backprop
    loss = criterion(obv_reward, target_reward)  # MSELoss
    loss.backward()
```
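For comparison, here is a minimal sketch of the usual Double DQN target for a single transition. It differs from make_step in two places that may matter. First, the greedy next action is chosen by the online network on next_observation, whereas make_step takes the argmax of model(observation), i.e. the current state. Second, terminal transitions do not bootstrap. The done flag is an assumption: the data tuples above do not store one, so it would have to be recorded when the transitions are collected. The helper name double_dqn_target is hypothetical.

```python
import torch
import torch.nn as nn


def double_dqn_target(model, target_model, reward, next_observation, done, gamma):
    """Hypothetical helper: Double DQN target for one transition.

    The online network selects the greedy next action, the target network
    evaluates it, and terminal transitions use the reward alone.
    """
    next_obs = torch.as_tensor(next_observation, dtype=torch.float32)
    with torch.no_grad():
        # Action selection: ONLINE network on the NEXT state
        next_action = torch.argmax(model(next_obs))
        # Action evaluation: TARGET network on the same next state
        next_q = target_model(next_obs)[next_action]
    # (1 - done) switches off bootstrapping at the end of an episode
    return reward + gamma * next_q * (1.0 - float(done))
```

The result could stand in for target_reward, e.g. `loss = criterion(model(torch.as_tensor(observation, dtype=torch.float32))[action], target)`. Both sides are then 0-dimensional, which also avoids the mismatch between q[action] (a scalar) and torch.Tensor([reward]) (shape [1]) that MSELoss warns about.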
Code wrapping make_step:
```python
optimizer.zero_grad()  # RMSprop on net
if e % 2 == 0:
    target_net.load_state_dict(net.state_dict())
for i in range(len(data)):
    observation, action, reward, next_observation = data[i]
    make_step(net, target_net, optimizer, criterion, observation, action, reward, next_observation)
GAMMA *= GAMMA
optimizer.step()
```
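Independently of the target-network schedule, GAMMA *= GAMMA squares the discount factor on every pass, so it collapses toward zero within a few epochs rather than decaying slowly. The sketch below traces it, assuming a starting value of 0.99 (the question does not show the real one):

```python
# Hypothetical starting value: the question does not show how GAMMA is initialised.
GAMMA = 0.99

gamma_per_epoch = []
for epoch in range(10):
    gamma_per_epoch.append(GAMMA)
    GAMMA *= GAMMA  # same update as in the wrapping code: squares the discount

# After k epochs the discount is 0.99 ** (2 ** k), a doubly exponential decay.
for epoch, g in enumerate(gamma_per_epoch):
    print(f"epoch {epoch}: GAMMA = {g:.6f}")
```

With GAMMA near zero, the target reduces to the immediate reward. Because CartPole pays +1 on every step, both actions then end up with roughly the same value and the policy has nothing to prefer. If GAMMA is meant to be a fixed discount factor, dropping that line is worth trying.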
What am I doing wrong? Thank you.
Updating the target network less often (i.e. increasing the interval between target updates from every 2 epochs to every 100) can solve the problem:
```python
optimizer.zero_grad()  # RMSprop on net
if e % 100 == 0:  # sync the target network every 100 epochs instead of every 2
    target_net.load_state_dict(net.state_dict())
for i in range(len(data)):
    observation, action, reward, next_observation = data[i]
    make_step(net, target_net, optimizer, criterion, observation, action, reward, next_observation)
GAMMA *= GAMMA
optimizer.step()
```
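The change replaces a near-continuous copy with a periodic hard update. Assuming the loop body shown runs once per value of e, optimizer.step() is called once per epoch. With e % 2, the target network was therefore never more than one gradient step behind the online network, which gives the bootstrap target almost no stability. A minimal sketch of the same hard-update rule as a reusable helper (the names maybe_sync_target and TARGET_UPDATE_INTERVAL are hypothetical):

```python
import torch
import torch.nn as nn

# Hypothetical constant; 100 matches the interval used in the answer.
TARGET_UPDATE_INTERVAL = 100


def maybe_sync_target(step, net, target_net, interval=TARGET_UPDATE_INTERVAL):
    """Copy the online weights into the target network every `interval` steps.

    Returns True if a copy was made. Between copies the target network is left
    untouched, so the bootstrap target does not move while the online net learns.
    """
    if step % interval == 0:
        target_net.load_state_dict(net.state_dict())
        return True
    return False


# Usage: over 250 steps the target is refreshed at steps 0, 100 and 200.
net = nn.Linear(4, 2)
target_net = nn.Linear(4, 2)
syncs = sum(maybe_sync_target(step, net, target_net) for step in range(250))
```

load_state_dict copies values into the target's own parameters, so later in-place updates to net do not leak into target_net until the next sync.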