My DQN for a FlappyBird game made with pygame in Python doesn't learn


I created a DQN that is supposed to learn to play a FlappyBird clone made with pygame, but my problem is that when I run the code the DQN doesn't seem to learn: the loss does not decrease and the bird never manages to pass the first pipe.
I already tried changing some of the hyperparameters but nothing changed. I changed the learning rate and tried decaying it over time, and I tried increasing the target network update interval, the size of the replay memory, and the batch size. I think the problem may come from the FlappyBird environment itself, but I don't see what I should change.
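(For context, the learning-rate decay I mention could be wired up roughly like this with torch.optim.lr_scheduler.StepLR; the step_size and gamma values below are placeholders, not the exact schedule I used.)

import torch.nn as nn
import torch.optim as optim

# Minimal sketch: decay the optimizer's learning rate on a fixed schedule.
# The tiny network and the step_size/gamma values are placeholders, not tuned for this task.
net = nn.Sequential(nn.Linear(9, 20), nn.ReLU(), nn.Linear(20, 2))
optimizer = optim.Adam(net.parameters(), lr=0.0005)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.995)

num_episodes = 500000  # same value as in main.py below
for episode in range(num_episodes):
    # ... run one episode and its optimizer.step() calls here ...
    scheduler.step()  # multiplies the learning rate by gamma every step_size episodes
    if episode % 500 == 0:
        print("Learning rate:", optimizer.param_groups[0]["lr"])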

main.py

# Import Libraries
import pygame
import sys
from FlappyBird import FlappyBird
import math
import random
import numpy as np
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

# ----------------------------------------------------------------------
# Deep Q-Network
class DQN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=9, out_features=20)
        self.out = nn.Linear(in_features=20, out_features=2)

    def forward(self, t):
        t = F.relu(self.fc1(t))
        # t = F.relu(self.fc2(t))
        t = self.out(t)
        return t

# ----------------------------------------------------------------------
# Experience class
Experience = namedtuple(
    'Experience',
    ('state', 'action', 'next_state', 'reward')
)

# ----------------------------------------------------------------------
# Replay Memory
class ReplayMemory:
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.push_count = 0

    def push(self, experience):
        if len(self.memory) < self.capacity:
            self.memory.append(experience)
        else:
            self.memory[self.push_count % self.capacity] = experience
        self.push_count += 1

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def can_provide_sample(self, batch_size):
        return len(self.memory) >= batch_size

# ----------------------------------------------------------------------
# Epsilon Greedy Strategy
class EpsilonGreedyStrategy:
    def __init__(self, start, end, decay):
        self.start = start
        self.end = end
        self.decay = decay

    def get_exploration_rate(self, current_step):
        return self.end + (self.start - self.end) * math.exp(-current_step * self.decay)

# ----------------------------------------------------------------------
# Reinforcement Learning Agent
class Agent:
    def __init__(self, strategy, num_actions, device):
        self.current_step = 0
        self.strategy = strategy
        self.num_actions = num_actions
        self.device = device

    def select_action(self, state, policy_net):
        rate = strategy.get_exploration_rate(self.current_step)
        self.current_step += 1
        if rate > random.random():
            action = random.randrange(self.num_actions)
            return torch.tensor([action]).to(device)  # explore
        else:
            with torch.no_grad():
                return policy_net(state).argmax(dim=1).to(device)  # exploit

# ----------------------------------------------------------------------
# Environment Manager
# FlappyBird.py

# ----------------------------------------------------------------------
# Progression Graph
def plot(values, moving_avg_period):
    plt.figure(2)
    plt.clf()
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Score')
    plt.plot(values)
    moving_avg = get_moving_average(moving_avg_period, values)
    plt.plot(moving_avg)
    plt.pause(0.001)
    print("Episode", len(values), "\n", moving_avg_period, "episode moving avg:", moving_avg[-1])


def get_moving_average(period, values):
    values = torch.tensor(values, dtype=torch.float)
    if len(values) >= period:
        moving_avg = values.unfold(dimension=0, size=period, step=1).mean(dim=1).flatten(start_dim=0)
        moving_avg = torch.cat((torch.zeros(period - 1), moving_avg))
        return moving_avg.numpy()
    else:
        moving_avg = torch.zeros(len(values))
        return moving_avg.numpy()

# ----------------------------------------------------------------------
# Tensor processing
def extract_tensors(experiences):
    batch = Experience(*zip(*experiences))
    t1 = torch.cat(batch.state)
    t2 = torch.cat(batch.action)
    t3 = torch.cat(batch.reward)
    t4 = torch.cat(batch.next_state)
    return t1, t2, t3, t4

# ----------------------------------------------------------------------
# Q-Value Calculator
class QValues:
    device = torch.device("cpu")

    @staticmethod
    def get_current(policy_net, states, actions):
        return policy_net(states).gather(dim=1, index=actions.unsqueeze(-1))

    @staticmethod
    def get_next(target_net, next_states):
        final_state_locations = next_states.flatten(start_dim=1).max(dim=1)[0].eq(0).type(torch.bool)
        non_final_state_locations = (final_state_locations == False)
        non_final_states = next_states[non_final_state_locations]
        batch_size = next_states.shape[0]
        values = torch.zeros(batch_size).to(QValues.device)
        values[non_final_state_locations] = target_net(non_final_states).max(dim=1)[0].detach()
        return values

# ----------------------------------------------------------------------
# Main Program
FIRST_TIME = True  # True if this is the first time you run this code

batch_size = 128
gamma = 0.999
eps_start = 1
eps_end = 0.01
eps_decay = 0.000001
target_update = 3000
memory_size = 500000
lr = 0.0005
num_episodes = 500000
weight_save = 100
lr_update = 500

device = torch.device("cpu")
strategy = EpsilonGreedyStrategy(eps_start, eps_end, eps_decay)
agent = Agent(strategy, 2, device)
memory = ReplayMemory(memory_size)

policy_net = DQN().to(device)
target_net = DQN().to(device)
target_net.eval()

optimizer = optim.Adam(params=policy_net.parameters(), lr=lr)
if FIRST_TIME:
    torch.save(optimizer.state_dict(), "./Optimizer_weight/weight.pt")
optimizer.load_state_dict(torch.load("./Optimizer_weight/weight.pt"))

score = []

if FIRST_TIME:
    target_net.load_state_dict(policy_net.state_dict())
    torch.save(policy_net.state_dict(), "./Policy_weight/weight.pt")
    torch.save(target_net.state_dict(), "./Target_weight/weight.pt")
policy_net.load_state_dict(torch.load("./Policy_weight/weight.pt"))
target_net.load_state_dict(torch.load("./Target_weight/weight.pt"))

pygame.mixer.pre_init(frequency=44100, size=16, channels=1, buffer=512)
pygame.init()
em = FlappyBird()

for episode in range(num_episodes):
    em.start_game(em)
    state = torch.tensor([em.get_state(em)]).to(device).float()

    while True:
        action = agent.select_action(state, policy_net)
        reward, next_state = em.step(em, action)
        reward = torch.tensor([reward]).to(device).float()
        next_state = torch.tensor([next_state]).to(device).float()
        memory.push(Experience(state, action, next_state, reward))
        state = next_state

        if memory.can_provide_sample(batch_size):
            experiences = memory.sample(batch_size)
            states, actions, rewards, next_states = extract_tensors(experiences)

            current_q_values = QValues.get_current(policy_net, states, actions)
            next_q_values = QValues.get_next(target_net, next_states)
            target_q_values = (next_q_values * gamma) + rewards

            loss = F.mse_loss(current_q_values, target_q_values.unsqueeze(1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if em.is_done(em):
            score.append(em.get_score(em))
            plot(score, 100)
            print("Exploration: ", round(strategy.get_exploration_rate(agent.current_step) * 100), "%")
            if memory.can_provide_sample(batch_size):
                print("Loss: ", loss.data)
            break

    if episode % target_update == 0:
        target_net.load_state_dict(policy_net.state_dict())
        torch.save(target_net.state_dict(), "./Target_weight/weight.pt")
        print("Target Updated")

    if episode % weight_save == 0:
        torch.save(policy_net.state_dict(), "./Policy_weight/weight.pt")
        torch.save(optimizer.state_dict(), "./Optimizer_weight/weight.pt")

    if episode % lr_update == 0:
        lr = lr * 0.995
        print("Learning Rate: ", lr)

pygame.quit()
sys.exit()

FlappyBird.py

# Imports
import pygame
import sys
import random
import math

# ----------------------------------------------------------------------
# Game Class
class FlappyBird:
    def __init__(self):
        self.pipe_is_spawned = False
        self.screen = pygame.display.set_mode((576, 1024))
        self.clock = pygame.time.Clock()
        self.game_font = pygame.font.Font('Flappy bird assets/04B_19.TTF', 60)

        # Game Variables
        self.gravity = 0.6
        self.bird_movement = 0
        self.game_active = False
        self.score = 0
        self.high_score = 0
        self.speed = 100

        self.bg_surface = pygame.image.load('Flappy bird assets/sprites/background-day.png').convert()
        self.bg_surface = pygame.transform.scale2x(self.bg_surface)

        self.floor_surface = pygame.image.load('Flappy bird assets/sprites/base.png').convert()
        self.floor_surface = pygame.transform.scale2x(self.floor_surface)
        self.floor_x_pos = 0

        self.bird_downflap = pygame.transform.scale2x(
            pygame.image.load('Flappy bird assets/sprites/bluebird-downflap.png')).convert_alpha()
        self.bird_midflap = pygame.transform.scale2x(
            pygame.image.load('Flappy bird assets/sprites/bluebird-midflap.png')).convert_alpha()
        self.bird_upflap = pygame.transform.scale2x(
            pygame.image.load('Flappy bird assets/sprites/bluebird-upflap.png')).convert_alpha()
        self.bird_frames = [self.bird_downflap, self.bird_midflap, self.bird_upflap]
        self.bird_index = 0
        self.bird_surface = self.bird_frames[self.bird_index]
        self.bird_rect = self.bird_surface.get_rect(center=(100, 512))

        self.BIRDFLAP = pygame.USEREVENT + 1
        pygame.time.set_timer(self.BIRDFLAP, 800)
        self.SPEEDUP = pygame.USEREVENT + 2
        pygame.time.set_timer(self.SPEEDUP, 40000)

        self.pipe_surface = pygame.image.load('Flappy bird assets/sprites/pipe-green.png').convert()
        self.pipe_surface = pygame.transform.scale2x(self.pipe_surface)
        self.pipe_list = []
        self.SPAWNPIPE = pygame.USEREVENT
        pygame.time.set_timer(self.SPAWNPIPE, 2400)
        self.pipe_height = [525, 550, 600, 650, 675]

        self.point_collider = pygame.Surface((10, 300))
        self.point_collider.set_alpha(0)
        self.collider_list = []

        self.game_over_surface = pygame.image.load('Flappy bird assets/sprites/message.png').convert_alpha()
        self.game_over_surface = pygame.transform.scale2x(self.game_over_surface)
        self.game_over_rect = self.game_over_surface.get_rect(center=(288, 512))

        self.flap_sound = pygame.mixer.Sound('Flappy bird assets/audio/wing.wav')
        self.death_sound = pygame.mixer.Sound('Flappy bird assets/audio/die.wav')
        self.hit_sound = pygame.mixer.Sound('Flappy bird assets/audio/hit.wav')
        self.point_sound = pygame.mixer.Sound('Flappy bird assets/audio/point.wav')

        self.rewardMultiplier = 1
        self.INCREASEREWARD = pygame.USEREVENT
        pygame.time.set_timer(self.INCREASEREWARD, 2400)

    @staticmethod
    def draw_floor(self):
        self.screen.blit(self.floor_surface, (self.floor_x_pos, 900))
        self.screen.blit(self.floor_surface, (self.floor_x_pos + 576, 900))

    @staticmethod
    def create_pipe(self):
        random_pipe_pos = random.choice(self.pipe_height)
        bottomPipe = self.pipe_surface.get_rect(midtop=(700, random_pipe_pos))
        topPipe = self.pipe_surface.get_rect(midbottom=(700, random_pipe_pos - 300))
        point_collider_rect = self.point_collider.get_rect(midbottom=(700, random_pipe_pos))
        return bottomPipe, topPipe, point_collider_rect

    @staticmethod
    def move_pipes(self, pipes, colliders):
        for pipe in pipes:
            pipe.centerx -= self.speed
        for collider in colliders:
            collider.centerx -= self.speed
        return pipes, colliders

    @staticmethod
    def draw_pipes(self, pipes, colliders):
        for pipe in pipes:
            if pipe.bottom >= 1024:
                self.screen.blit(self.pipe_surface, pipe)
            else:
                flip_pipe = pygame.transform.flip(self.pipe_surface, False, True)
                self.screen.blit(flip_pipe, pipe)
        for collider in colliders:
            self.screen.blit(self.point_collider, collider)

    @staticmethod
    def check_collision(self, pipes):
        for pipe in pipes:
            if self.bird_rect.colliderect(pipe):
                self.hit_sound.play()
                return False
        if self.bird_rect.top <= -100 or self.bird_rect.bottom >= 900:
            self.death_sound.play()
            return False
        return True

    @staticmethod
    def check_pipe_reached(self, colliders):
        for collider in colliders:
            if self.bird_rect.colliderect(collider):
                colliders.remove(collider)
                self.pipe_is_spawned = False
                return colliders, True
        return colliders, False

    @staticmethod
    def rotate_bird(self, bird):
        new_bird = pygame.transform.rotozoom(bird, -self.bird_movement * 2, 1)
        return new_bird

    @staticmethod
    def bird_animation(self):
        new_bird = self.bird_frames[self.bird_index]
        new_bird_rect = new_bird.get_rect(center=(100, self.bird_rect.centery))
        return new_bird, new_bird_rect

    @staticmethod
    def score_display(self):
        score_surface = self.game_font.render(str(self.score), True, (255, 255, 255))
        score_rect = score_surface.get_rect(center=(288, 100))
        self.screen.blit(score_surface, score_rect)

    @staticmethod
    def is_done(self):
        if not self.game_active:
            self.pipe_is_spawned = False
        return not self.game_active

    @staticmethod
    def get_score(self):
        return self.score

    @staticmethod
    def start_game(self):
        self.game_active = True
        self.pipe_list.clear()
        self.collider_list.clear()
        self.bird_rect.center = (100, 512)
        self.bird_movement = 0
        self.score = 0
        self.speed = 4
        self.pipe_is_spawned = False

    @staticmethod
    def check_event(self, action):
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()
            if action == 1:
                self.bird_movement = 0
                self.bird_movement -= 15
                self.flap_sound.play()
            if event.type == self.SPAWNPIPE and self.game_active:
                bottom_pipe, top_pipe, win_collider = self.create_pipe(self)
                self.pipe_list.append(bottom_pipe)
                self.pipe_list.append(top_pipe)
                self.collider_list.append(win_collider)
                self.pipe_is_spawned = True
            if event.type == self.BIRDFLAP:
                if self.bird_index < 2:
                    self.bird_index += 1
                else:
                    self.bird_index = 0
                self.bird_surface, self.bird_rect = self.bird_animation(self)
            if event.type == self.SPEEDUP and self.game_active:
                if self.speed < 8:
                    self.speed += 0.05
            if event.type == self.INCREASEREWARD and self.game_active:
                self.rewardMultiplier += 0.01

    @staticmethod
    def is_game_active(self):
        # Bird
        self.bird_movement += self.gravity
        rotated_bird = self.rotate_bird(self, self.bird_surface)
        self.bird_rect.centery += self.bird_movement
        self.screen.blit(rotated_bird, self.bird_rect)
        self.game_active = self.check_collision(self, self.pipe_list)
        collider_list, pipe_reached = self.check_pipe_reached(self, self.collider_list)
        if pipe_reached:
            self.point_sound.play()
            self.score += 1

        # Pipes
        pipe_list, collider_list = self.move_pipes(self, self.pipe_list, self.collider_list)
        self.draw_pipes(self, pipe_list, collider_list)
        self.score_display(self)

    @staticmethod
    def draw_frame(self):
        self.screen.blit(self.bg_surface, (0, 0))
        if self.game_active:
            self.is_game_active(self)
        else:
            self.screen.blit(self.game_over_surface, self.game_over_rect)

        # Floor
        self.floor_x_pos -= 4
        self.draw_floor(self)
        if self.floor_x_pos <= -576:
            self.floor_x_pos = 0

        pygame.display.update()
        self.clock.tick(60)

    @staticmethod
    def get_reward(self):
        point = 0.1 * self.rewardMultiplier
        if not self.check_collision:
            point = -100 * self.rewardMultiplier
        elif self.check_pipe_reached:
            point = 50 * self.rewardMultiplier
        return point

    @staticmethod
    def get_state(self):
        if self.pipe_is_spawned:
            position = [self.bird_rect.centery,
                        self.collider_list[0].centerx,
                        self.collider_list[0].centery,
                        math.sqrt((self.collider_list[0].centerx - self.bird_rect.centerx) ** 2 +
                                  (self.collider_list[0].centery - self.bird_rect.centery) ** 2),
                        self.collider_list[0].centery - self.bird_rect.centery,
                        self.collider_list[0].centerx - self.bird_rect.centerx,
                        self.bird_movement,
                        900 - self.bird_rect.bottom,
                        900 - self.collider_list[0].centery]
        else:
            position = [self.bird_rect.centery, 0, 0, 0, 0, 0,
                        900 - self.bird_rect.bottom, 50, self.bird_movement]
        return position

    @staticmethod
    def step(self, action):
        self.check_event(self, action)
        self.draw_frame(self)
        point = self.get_reward(self)
        self.is_done(self)
        position = self.get_state(self)
        return point, position

I'd say try adding another layer to your model architecture first, because it's very small. Also, the observations you're receiving from that environment contain some very large integers, so you might want to normalize each of those features. Another, unrelated suggestion: it is more efficient to implement the replay buffer with a queue from collections.deque(maxlen=...).
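A minimal sketch of what those three changes could look like, assuming the 9-feature state and the 576x1024 screen from the question (the layer sizes and normalization divisors below are illustrative, not tuned):

import random
from collections import deque

import torch
import torch.nn as nn
import torch.nn.functional as F

# A slightly deeper network than the original single hidden layer (sizes are illustrative).
class DQN(nn.Module):
    def __init__(self, n_inputs=9, n_actions=2):
        super().__init__()
        self.fc1 = nn.Linear(n_inputs, 64)
        self.fc2 = nn.Linear(64, 64)
        self.out = nn.Linear(64, n_actions)

    def forward(self, t):
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        return self.out(t)

# Scale the raw pixel-coordinate features down to roughly [-1, 1].
# The divisors assume the 576x1024 screen; adjust them to the actual range of each feature.
def normalize_state(raw_state):
    scales = torch.tensor([1024.0, 576.0, 1024.0, 1180.0, 1024.0, 576.0, 20.0, 1024.0, 1024.0])
    return torch.tensor(raw_state, dtype=torch.float32) / scales

# Replay buffer backed by a deque: appending past maxlen automatically drops the oldest item.
class ReplayMemory:
    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, experience):
        self.memory.append(experience)

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def can_provide_sample(self, batch_size):
        return len(self.memory) >= batch_size

With this, the list returned by em.get_state(em) would be passed through normalize_state before being fed to the network, and the deque handles eviction so the manual push_count bookkeeping is no longer needed.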
