I tried to implement the simplest Deep Q-Learning algorithm. I think I've implemented it correctly, and I know that Deep Q-Learning struggles with divergence, but the reward declines very fast and the loss diverges. I would be grateful if someone could help me by pointing out the right hyperparameters or by telling me whether I implemented the algorithm incorrectly. I've tried a lot of hyperparameter combinations and have also tried changing the complexity of the QNet.
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import collections
import numpy as np
import matplotlib.pyplot as plt
import gym
from torch.nn.modules.linear import Linear
from torch.nn.modules.loss import MSELoss
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Args:
        max_replay_size: Maximum number of transitions kept. Once the buffer
            is full, each push drops the oldest stored transition.
        batch_size: Number of transitions returned by ``sample_batch``.
    """

    def __init__(self, max_replay_size, batch_size):
        self.max_replay_size = max_replay_size
        self.batch_size = batch_size
        # A bounded deque evicts its oldest entry on append, so no manual
        # length check / popleft is needed in push().
        self.buffer = collections.deque(maxlen=max_replay_size)

    def push(self, *transition):
        """Store one transition, given as positional fields.

        ``sample_batch`` unpacks five fields per transition:
        ``(state, action, reward, next_state, done)``.
        """
        self.buffer.append(transition)

    def sample_batch(self):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            A tuple ``(state, action, reward, next_state, done)`` of NumPy
            arrays, each stacked along a leading batch axis.

        Raises:
            ValueError: Raised by ``np.random.choice`` when fewer than
                ``batch_size`` transitions are stored (sampling is without
                replacement).
        """
        indices = np.random.choice(len(self.buffer), self.batch_size, replace=False)
        batch = [self.buffer[index] for index in indices]
        state, action, reward, next_state, done = zip(*batch)
        return (
            np.array(state),
            np.array(action),
            np.array(reward),
            np.array(next_state),
            np.array(done),
        )

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)
class QNet(nn.Module):
    """Small fully connected network that outputs one Q-value per action.

    The input of size ``state_dim`` passes through a 64-unit hidden layer
    with ReLU, then a linear output layer of size ``action_dim``.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden_units = 64
        # Attribute names are kept as-is: they are the state_dict keys used
        # when weights are copied between networks.
        self.linear1 = Linear(state_dim, hidden_units)
        self.linear2 = Linear(hidden_units, action_dim)

    def forward(self, x):
        """Return the Q-values computed for input ``x``."""
        hidden = F.relu(self.linear1(x))
        return self.linear2(hidden)
def train(replay_buffer, model, target_model, discount_factor, mse, optimizer):
    """Run one gradient step of Q-learning on a sampled minibatch.

    Args:
        replay_buffer: Object whose ``sample_batch()`` returns
            ``(state, action, reward, next_state, done)`` as array-likes.
        model: Online Q-network being optimised.
        target_model: Network used to evaluate the bootstrap target; it
            receives no gradients here.
        discount_factor: Discount applied to the next-state value.
        mse: Loss callable taking ``(prediction, target)``.
        optimizer: Optimizer holding ``model``'s parameters.

    Returns:
        The loss of this step as a Python float (computed before the
        parameter update).
    """
    state, action, reward, next_state, done = replay_buffer.sample_batch()
    state = torch.tensor(state, dtype=torch.float)
    next_state = torch.tensor(next_state, dtype=torch.float)
    action = torch.tensor(action, dtype=torch.int64).unsqueeze(-1)
    reward = torch.tensor(reward, dtype=torch.float)
    # 1.0 for transitions that continue, 0.0 for transitions that ended the
    # episode.
    not_done = 1.0 - torch.tensor(done, dtype=torch.float)

    # Q(s, a) of the online network for the actions that were actually taken.
    q_values = model(state).gather(1, action)

    with torch.no_grad():
        max_next_q_values = target_model(next_state).max(1)[0]
        # A terminal transition has no successor, so its target is the reward
        # alone. Bootstrapping from next_state there feeds the network values
        # for states that are never acted from.
        q_target_value = reward + discount_factor * max_next_q_values * not_done

    optimizer.zero_grad()
    loss = mse(q_values, q_target_value.unsqueeze(1))
    loss.backward()
    optimizer.step()
    return loss.item()
def _plot_curve(episodes, values, label, y_label):
    """Show one per-episode metric as a line plot in its own figure."""
    _, ax = plt.subplots(figsize=(10, 6))
    ax.plot(episodes, values, label=label)
    ax.set_xlabel('Episodes', fontsize=15)
    ax.set_ylabel(y_label, fontsize=15)
    ax.legend(prop={'size': 15})
    plt.show()


def main():
    """Train a DQN agent on CartPole-v0, then plot reward and loss per episode.

    The replay buffer is first filled to capacity with random-policy
    transitions. Training then runs epsilon-greedy episodes, copying the
    online weights into the target network and taking gradient steps at
    fixed environment-step intervals.
    """
    # Hyperparameters
    EPISODES = 10000
    MAX_REPLAY_SIZE = 10000
    BATCH_SIZE = 32
    EPSILON = 1.0
    MIN_EPSILON = 0.05
    DISCOUNT_FACTOR = 0.95
    DECAY_RATE = 0.99
    LEARNING_RATE = 1e-3
    SYNCHRONISATION = 33  # environment steps between target-network syncs
    # Environment steps between gradient updates (despite the name).
    # NOTE(review): one update per 32 steps is infrequent; a smaller interval
    # may learn faster -- untested here.
    EVALUATION = 32

    # Environment, online/target networks, optimizer, loss and replay buffer.
    # NOTE(review): the reset()/step() usage below assumes the classic gym API
    # (reset -> observation, step -> 4-tuple) -- confirm against the installed
    # gym version.
    env = gym.make("CartPole-v0")
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    model = QNet(state_dim=state_dim, action_dim=action_dim)
    target_model = QNet(state_dim=state_dim, action_dim=action_dim)
    target_model.load_state_dict(model.state_dict())
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    mse = MSELoss()
    replay_buffer = ReplayBuffer(max_replay_size=MAX_REPLAY_SIZE, batch_size=BATCH_SIZE)

    # Pre-fill the buffer with random-policy experience. `<` (rather than
    # `!=`) cannot loop forever if the length ever overshoots the cap.
    while len(replay_buffer) < MAX_REPLAY_SIZE:
        state = env.reset()
        done = False
        while not done:
            action = env.action_space.sample()
            next_state, reward, done, _ = env.step(action)
            replay_buffer.push(state, action, reward, next_state, done)
            state = next_state

    # Main training loop
    global_step = 0  # environment steps taken across all training episodes
    epsilon = EPSILON
    history = {'Episode': [], 'Reward': [], 'Loss': []}
    for episode in range(EPISODES):
        total_reward = 0.0
        total_loss = 0.0
        iterations = 0
        state = env.reset()
        done = False
        while not done:
            global_step += 1

            # Epsilon-greedy action selection
            if np.random.rand() < epsilon:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    output = model(torch.tensor(state, dtype=torch.float)).numpy()
                action = int(np.argmax(output))

            # Step the environment and record the transition
            next_state, reward, done, _ = env.step(action)
            total_reward += reward
            replay_buffer.push(state, action, reward, next_state, done)
            state = next_state

            if global_step % SYNCHRONISATION == 0:
                target_model.load_state_dict(model.state_dict())
            if global_step % EVALUATION == 0:
                loss = train(replay_buffer=replay_buffer, model=model,
                             target_model=target_model,
                             discount_factor=DISCOUNT_FACTOR,
                             mse=mse, optimizer=optimizer)
                total_loss += loss
            iterations += 1

        print(f"Episode {episode} is concluded in {iterations} iterations with a total reward of {total_reward}")

        # Clamp so the decayed value never drops below the configured floor.
        epsilon = max(MIN_EPSILON, epsilon * DECAY_RATE)
        history['Episode'].append(episode)
        history['Reward'].append(total_reward)
        history['Loss'].append(total_loss)

    # Plot the reward and the loss per episode
    _plot_curve(history['Episode'], history['Reward'], "Reward", 'Total Reward per Episode')
    _plot_curve(history['Episode'], history['Loss'], "Loss", 'Total Loss per Episode')
# Start training only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()
Your code looks fine and well written, and the hyperparameters seem reasonable (except maybe for the update frequency, which may be too low). I think that the Q-network is quite small, with a single hidden dense layer.
A deeper model could likely do better (probably not more than 3-4 layers though), but you said that you already tried different network sizes.
Another thing that comes to mind is the target update. You are doing a hard update every n steps; a soft update may help a bit, but I wouldn't count on it.
You can also try lowering the learning rate a bit, but I imagine you already did that.
My suggestions are:
Sadly DQN is not ideal and won't converge for many problems, but it should be able to solve cartpole.
Your code looks fine, but I think your hyperparameters are not ideal. I would change two, potentially three, things:
Hope this helps, good luck:)
The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address. If you have any questions, please contact: yoyou2525@163.com.