[英]My DQN for a FlappyBird game made with pygame in Python doesn't learn
主文件
# Import Libraries
import math
import os
import random
import sys
from collections import namedtuple
from itertools import count

import matplotlib.pyplot as plt
import numpy as np
import pygame
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T

from FlappyBird import FlappyBird


# ----------------------------------------------------------------------------------------------------------------------
# Deep Q-Network
class DQN(nn.Module):
    """Small fully connected Q-network: 9 state features -> 20 hidden units -> 2 action values."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=9, out_features=20)
        self.out = nn.Linear(in_features=20, out_features=2)

    def forward(self, t):
        """Return the Q-value of each of the two actions for a batch of states of shape (batch, 9)."""
        t = F.relu(self.fc1(t))
        return self.out(t)


# ----------------------------------------------------------------------------------------------------------------------
# Experience class
# `done` marks a transition whose next_state ended the episode.  It defaults to
# False so that four-field construction keeps working.
Experience = namedtuple(
    'Experience',
    ('state', 'action', 'next_state', 'reward', 'done'),
    defaults=(False,),
)


# ----------------------------------------------------------------------------------------------------------------------
# Replay Memory
class ReplayMemory:
    """Fixed-capacity ring buffer of experiences; the oldest entry is overwritten once full."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.push_count = 0

    def push(self, experience):
        """Store an experience, overwriting the oldest one when the buffer is at capacity."""
        if len(self.memory) < self.capacity:
            self.memory.append(experience)
        else:
            self.memory[self.push_count % self.capacity] = experience
        self.push_count += 1

    def sample(self, batch_size):
        """Return `batch_size` distinct experiences drawn uniformly at random."""
        return random.sample(self.memory, batch_size)

    def can_provide_sample(self, batch_size):
        """Return True when at least `batch_size` experiences are stored."""
        return len(self.memory) >= batch_size


# ----------------------------------------------------------------------------------------------------------------------
# Epsilon Greedy Strategy
class EpsilonGreedyStrategy:
    """Exponentially decaying exploration rate between `start` and `end`."""

    def __init__(self, start, end, decay):
        self.start = start
        self.end = end
        self.decay = decay

    def get_exploration_rate(self, current_step):
        """Return end + (start - end) * exp(-current_step * decay)."""
        return self.end + (self.start - self.end) * math.exp(-current_step * self.decay)


# ----------------------------------------------------------------------------------------------------------------------
# Reinforcement Learning Agent
class Agent:
    """Epsilon-greedy action selector that counts how many actions it has chosen."""

    def __init__(self, strategy, num_actions, device):
        self.current_step = 0
        self.strategy = strategy
        self.num_actions = num_actions
        self.device = device

    def select_action(self, state, policy_net):
        """Return a tensor of shape (1,) holding the chosen action index.

        Uses the strategy and device given to the constructor (the previous
        version silently read module-level globals of the same names).
        """
        rate = self.strategy.get_exploration_rate(self.current_step)
        self.current_step += 1

        if rate > random.random():
            action = random.randrange(self.num_actions)
            return torch.tensor([action]).to(self.device)  # explore

        with torch.no_grad():
            return policy_net(state).argmax(dim=1).to(self.device)  # exploit


# ----------------------------------------------------------------------------------------------------------------------
# Environment Manager
# FlappyBird.py


# ----------------------------------------------------------------------------------------------------------------------
# Progression Graph
def plot(values, moving_avg_period):
    """Plot the per-episode scores and their moving average, and print the latest average."""
    plt.figure(2)
    plt.clf()
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Score')
    plt.plot(values)

    moving_avg = get_moving_average(moving_avg_period, values)
    plt.plot(moving_avg)
    plt.pause(0.001)
    print("Episode", len(values), "\n", moving_avg_period, "episode moving avg:", moving_avg[-1])


def get_moving_average(period, values):
    """Return a numpy array, as long as `values`, with the trailing `period`-mean.

    Positions that do not yet have `period` samples behind them are 0.
    """
    values = torch.tensor(values, dtype=torch.float)
    if len(values) < period:
        return torch.zeros(len(values)).numpy()

    moving_avg = values.unfold(dimension=0, size=period, step=1).mean(dim=1).flatten(start_dim=0)
    moving_avg = torch.cat((torch.zeros(period - 1), moving_avg))
    return moving_avg.numpy()


# ----------------------------------------------------------------------------------------------------------------------
# Tensor processing
def extract_tensors(experiences):
    """Concatenate a list of experiences into (states, actions, rewards, next_states) tensors."""
    batch = Experience(*zip(*experiences))

    states = torch.cat(batch.state)
    actions = torch.cat(batch.action)
    rewards = torch.cat(batch.reward)
    next_states = torch.cat(batch.next_state)

    return states, actions, rewards, next_states


def extract_done_mask(experiences):
    """Return a bool tensor, one entry per experience, that is True where the episode ended."""
    return torch.tensor([bool(experience.done) for experience in experiences], dtype=torch.bool)


# ----------------------------------------------------------------------------------------------------------------------
# Q-Value Calculator
class QValues:
    """Helpers that compute the two sides of the Q-learning update."""

    device = torch.device("cpu")

    @staticmethod
    def get_current(policy_net, states, actions):
        """Return Q(s, a) for the taken actions, shape (batch, 1)."""
        return policy_net(states).gather(dim=1, index=actions.unsqueeze(-1))

    @staticmethod
    def get_next(target_net, next_states, final_state_locations=None):
        """Return max_a' Q_target(s', a') per transition, with 0 for terminal transitions.

        `final_state_locations` is a bool tensor marking terminal transitions.
        When it is omitted, the legacy heuristic is used (a state whose largest
        feature is 0 counts as terminal).  That heuristic does not match the
        states this environment returns, so callers should pass the mask.
        """
        if final_state_locations is None:
            final_state_locations = next_states.flatten(start_dim=1).max(dim=1)[0].eq(0)
        final_state_locations = final_state_locations.to(next_states.device).type(torch.bool)

        non_final_state_locations = ~final_state_locations
        non_final_states = next_states[non_final_state_locations]

        batch_size = next_states.shape[0]
        values = torch.zeros(batch_size).to(QValues.device)
        with torch.no_grad():
            values[non_final_state_locations] = target_net(non_final_states).max(dim=1)[0]
        return values


# ----------------------------------------------------------------------------------------------------------------------
# Main Program
FIRST_TIME = True  # True if this is the first time you run this code

batch_size = 128
gamma = 0.999
eps_start = 1
eps_end = 0.01
eps_decay = 0.000001
target_update = 3000
memory_size = 500000
lr = 0.0005
num_episodes = 500000
weight_save = 100
lr_update = 500

device = torch.device("cpu")


def _save_state(state_dict, path):
    """Save `state_dict` to `path`, creating the parent directory if it is missing."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(state_dict, path)


def main():
    """Train the DQN agent on FlappyBird, checkpointing weights periodically."""
    strategy = EpsilonGreedyStrategy(eps_start, eps_end, eps_decay)
    agent = Agent(strategy, 2, device)
    memory = ReplayMemory(memory_size)

    policy_net = DQN().to(device)
    target_net = DQN().to(device)
    target_net.eval()

    optimizer = optim.Adam(params=policy_net.parameters(), lr=lr)
    if FIRST_TIME:
        _save_state(optimizer.state_dict(), "./Optimizer_weight/weight.pt")
    optimizer.load_state_dict(torch.load("./Optimizer_weight/weight.pt"))

    score = []

    if FIRST_TIME:
        target_net.load_state_dict(policy_net.state_dict())
        _save_state(policy_net.state_dict(), "./Policy_weight/weight.pt")
        _save_state(target_net.state_dict(), "./Target_weight/weight.pt")
    policy_net.load_state_dict(torch.load("./Policy_weight/weight.pt"))
    target_net.load_state_dict(torch.load("./Target_weight/weight.pt"))

    # Continue from whatever learning rate the (possibly resumed) optimizer holds.
    current_lr = optimizer.param_groups[0]["lr"]

    pygame.mixer.pre_init(frequency=44100, size=16, channels=1, buffer=512)
    pygame.init()
    em = FlappyBird()

    for episode in range(num_episodes):
        em.start_game(em)
        state = torch.tensor([em.get_state(em)]).to(device).float()
        loss = None

        while True:
            action = agent.select_action(state, policy_net)
            reward, next_state = em.step(em, action.item())
            done = em.is_done(em)

            reward = torch.tensor([reward]).to(device).float()
            next_state = torch.tensor([next_state]).to(device).float()
            memory.push(Experience(state, action, next_state, reward, done))
            state = next_state

            if memory.can_provide_sample(batch_size):
                experiences = memory.sample(batch_size)
                states, actions, rewards, next_states = extract_tensors(experiences)
                dones = extract_done_mask(experiences).to(device)

                current_q_values = QValues.get_current(policy_net, states, actions)
                # Terminal transitions contribute only their reward to the target.
                next_q_values = QValues.get_next(target_net, next_states, dones)
                target_q_values = (next_q_values * gamma) + rewards

                loss = F.mse_loss(current_q_values, target_q_values.unsqueeze(1))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if done:
                score.append(em.get_score(em))
                plot(score, 100)
                print("Exploration: ", round(strategy.get_exploration_rate(agent.current_step) * 100), "%")
                if loss is not None:
                    print("Loss: ", loss.item())
                break

        if episode % target_update == 0:
            target_net.load_state_dict(policy_net.state_dict())
            _save_state(target_net.state_dict(), "./Target_weight/weight.pt")
            print("Target Updated")

        if episode % weight_save == 0:
            _save_state(policy_net.state_dict(), "./Policy_weight/weight.pt")
            _save_state(optimizer.state_dict(), "./Optimizer_weight/weight.pt")

        if episode % lr_update == 0:
            current_lr = current_lr * 0.995
            # Rebinding a Python variable does not reach the optimizer; the
            # new rate has to be written into its parameter groups.
            for param_group in optimizer.param_groups:
                param_group["lr"] = current_lr
            print("Learning Rate: ", current_lr)

    pygame.quit()
    sys.exit()


if __name__ == "__main__":
    main()
FlappyBird.py
# Imports
import pygame
import sys
import random
import math


# ----------------------------------------------------------------------------------------------------------------------
# Game Class
class FlappyBird:
    """Pygame Flappy Bird clone exposed as a step-based environment.

    Every method except ``__init__`` is a ``@staticmethod`` that takes the
    instance explicitly, so callers write ``em.step(em, action)``.  That
    calling convention is kept on purpose so existing callers keep working.
    """

    def __init__(self):
        self.pipe_is_spawned = False
        # True only during a frame in which the bird crossed a pipe gap.
        self.pipe_passed = False
        self.screen = pygame.display.set_mode((576, 1024))
        self.clock = pygame.time.Clock()
        self.game_font = pygame.font.Font('Flappy bird assets/04B_19.TTF', 60)

        # Games Variables
        self.gravity = 0.6
        self.bird_movement = 0
        self.game_active = False
        self.score = 0
        self.high_score = 0
        self.speed = 100  # overwritten by start_game before any pipe moves

        self.bg_surface = pygame.image.load('Flappy bird assets/sprites/background-day.png').convert()
        self.bg_surface = pygame.transform.scale2x(self.bg_surface)

        self.floor_surface = pygame.image.load('Flappy bird assets/sprites/base.png').convert()
        self.floor_surface = pygame.transform.scale2x(self.floor_surface)
        self.floor_x_pos = 0

        self.bird_downflap = pygame.transform.scale2x(
            pygame.image.load('Flappy bird assets/sprites/bluebird-downflap.png')).convert_alpha()
        self.bird_midflap = pygame.transform.scale2x(
            pygame.image.load('Flappy bird assets/sprites/bluebird-midflap.png')).convert_alpha()
        self.bird_upflap = pygame.transform.scale2x(
            pygame.image.load('Flappy bird assets/sprites/bluebird-upflap.png')).convert_alpha()
        self.bird_frames = [self.bird_downflap, self.bird_midflap, self.bird_upflap]
        self.bird_index = 0
        self.bird_surface = self.bird_frames[self.bird_index]
        self.bird_rect = self.bird_surface.get_rect(center=(100, 512))

        self.BIRDFLAP = pygame.USEREVENT + 1
        pygame.time.set_timer(self.BIRDFLAP, 800)
        self.SPEEDUP = pygame.USEREVENT + 2
        pygame.time.set_timer(self.SPEEDUP, 40000)

        self.pipe_surface = pygame.image.load('Flappy bird assets/sprites/pipe-green.png').convert()
        self.pipe_surface = pygame.transform.scale2x(self.pipe_surface)
        self.pipe_list = []
        self.SPAWNPIPE = pygame.USEREVENT
        pygame.time.set_timer(self.SPAWNPIPE, 2400)
        self.pipe_height = [525, 550, 600, 650, 675]

        # Invisible rectangle spanning the pipe gap; touching it scores a point.
        self.point_collider = pygame.Surface((10, 300))
        self.point_collider.set_alpha(0)
        self.collider_list = []

        self.game_over_surface = pygame.image.load('Flappy bird assets/sprites/message.png').convert_alpha()
        self.game_over_surface = pygame.transform.scale2x(self.game_over_surface)
        self.game_over_rect = self.game_over_surface.get_rect(center=(288, 512))

        self.flap_sound = pygame.mixer.Sound('Flappy bird assets/audio/wing.wav')
        self.death_sound = pygame.mixer.Sound('Flappy bird assets/audio/die.wav')
        self.hit_sound = pygame.mixer.Sound('Flappy bird assets/audio/hit.wav')
        self.point_sound = pygame.mixer.Sound('Flappy bird assets/audio/point.wav')

        self.rewardMultiplier = 1
        # Own event id: this previously reused pygame.USEREVENT, the same id
        # as SPAWNPIPE, so the two timers were one and the same event.
        self.INCREASEREWARD = pygame.USEREVENT + 3
        pygame.time.set_timer(self.INCREASEREWARD, 2400)

    @staticmethod
    def draw_floor(self):
        """Blit two copies of the floor side by side so it can scroll seamlessly."""
        self.screen.blit(self.floor_surface, (self.floor_x_pos, 900))
        self.screen.blit(self.floor_surface, (self.floor_x_pos + 576, 900))

    @staticmethod
    def create_pipe(self):
        """Return (bottom pipe rect, top pipe rect, gap collider rect) for a new pipe at x=700."""
        random_pipe_pos = random.choice(self.pipe_height)
        bottomPipe = self.pipe_surface.get_rect(midtop=(700, random_pipe_pos))
        topPipe = self.pipe_surface.get_rect(midbottom=(700, random_pipe_pos - 300))
        point_collider_rect = self.point_collider.get_rect(midbottom=(700, random_pipe_pos))
        return bottomPipe, topPipe, point_collider_rect

    @staticmethod
    def move_pipes(self, pipes, colliders):
        """Shift pipes and colliders left by the current speed and return both lists.

        Pipes that have fully left the screen are dropped in place, so the
        caller's list does not grow for the whole episode.
        """
        for pipe in pipes:
            pipe.centerx -= self.speed
        for collider in colliders:
            collider.centerx -= self.speed
        pipes[:] = [pipe for pipe in pipes if pipe.right > 0]
        return pipes, colliders

    @staticmethod
    def draw_pipes(self, pipes, colliders):
        """Draw every pipe (top pipes flipped vertically) and the invisible colliders."""
        for pipe in pipes:
            if pipe.bottom >= 1024:
                self.screen.blit(self.pipe_surface, pipe)
            else:
                flip_pipe = pygame.transform.flip(self.pipe_surface, False, True)
                self.screen.blit(flip_pipe, pipe)
        for collider in colliders:
            self.screen.blit(self.point_collider, collider)

    @staticmethod
    def check_collision(self, pipes):
        """Return False if the bird touches a pipe or leaves the playfield, True otherwise."""
        for pipe in pipes:
            if self.bird_rect.colliderect(pipe):
                self.hit_sound.play()
                return False
        if self.bird_rect.top <= -100 or self.bird_rect.bottom >= 900:
            self.death_sound.play()
            return False
        return True

    @staticmethod
    def check_pipe_reached(self, colliders):
        """Remove the first gap collider the bird touches; return (colliders, reached flag)."""
        for collider in colliders:
            if self.bird_rect.colliderect(collider):
                # Safe although we are iterating: we return straight after removing.
                colliders.remove(collider)
                self.pipe_is_spawned = False
                return colliders, True
        return colliders, False

    @staticmethod
    def rotate_bird(self, bird):
        """Return the bird surface tilted according to its vertical speed."""
        new_bird = pygame.transform.rotozoom(bird, -self.bird_movement * 2, 1)
        return new_bird

    @staticmethod
    def bird_animation(self):
        """Return the current animation frame and a rect for it at the bird's height."""
        new_bird = self.bird_frames[self.bird_index]
        new_bird_rect = new_bird.get_rect(center=(100, self.bird_rect.centery))
        return new_bird, new_bird_rect

    @staticmethod
    def score_display(self):
        """Render the current score near the top of the screen."""
        score_surface = self.game_font.render(str(self.score), True, (255, 255, 255))
        score_rect = score_surface.get_rect(center=(288, 100))
        self.screen.blit(score_surface, score_rect)

    @staticmethod
    def is_done(self):
        """Return True when the episode is over; also clears `pipe_is_spawned` in that case."""
        if not self.game_active:
            self.pipe_is_spawned = False
        return not self.game_active

    @staticmethod
    def get_score(self):
        """Return the number of pipes passed in the current episode."""
        return self.score

    @staticmethod
    def start_game(self):
        """Reset the bird, pipes, score and speed for a new episode."""
        self.game_active = True
        self.pipe_list.clear()
        self.collider_list.clear()
        self.bird_rect.center = (100, 512)
        self.bird_movement = 0
        self.score = 0
        self.speed = 4
        self.pipe_is_spawned = False
        self.pipe_passed = False

    @staticmethod
    def check_event(self, action):
        """Apply the agent's action once, then process pending pygame events.

        `action == 1` means flap.  The flap used to sit inside the event loop,
        so it happened only when events were queued (and once per event).
        """
        if action == 1:
            self.bird_movement = 0
            self.bird_movement -= 15
            self.flap_sound.play()

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()
            if event.type == self.SPAWNPIPE and self.game_active:
                bottom_pipe, top_pipe, win_collider = self.create_pipe(self)
                self.pipe_list.append(bottom_pipe)
                self.pipe_list.append(top_pipe)
                self.collider_list.append(win_collider)
                self.pipe_is_spawned = True
            if event.type == self.BIRDFLAP:
                if self.bird_index < 2:
                    self.bird_index += 1
                else:
                    self.bird_index = 0
                self.bird_surface, self.bird_rect = self.bird_animation(self)
            if event.type == self.SPEEDUP and self.game_active:
                if self.speed < 8:
                    self.speed += 0.05
            if event.type == self.INCREASEREWARD and self.game_active:
                self.rewardMultiplier += 0.01

    @staticmethod
    def is_game_active(self):
        """Advance one frame of live gameplay: physics, collisions, scoring, pipes."""
        # Bird
        self.bird_movement += self.gravity
        rotated_bird = self.rotate_bird(self, self.bird_surface)
        self.bird_rect.centery += self.bird_movement
        self.screen.blit(rotated_bird, self.bird_rect)
        self.game_active = self.check_collision(self, self.pipe_list)
        collider_list, pipe_reached = self.check_pipe_reached(self, self.collider_list)
        if pipe_reached:
            self.point_sound.play()
            self.score += 1
        # Remembered so get_reward can report what happened in this frame.
        self.pipe_passed = pipe_reached

        # Pipes
        pipe_list, collider_list = self.move_pipes(self, self.pipe_list, self.collider_list)
        self.draw_pipes(self, pipe_list, collider_list)

        self.score_display(self)

    @staticmethod
    def draw_frame(self):
        """Draw one frame (gameplay or game-over screen) and tick the clock at 60 FPS."""
        self.pipe_passed = False  # set again by is_game_active if a pipe is crossed
        self.screen.blit(self.bg_surface, (0, 0))
        if self.game_active:
            self.is_game_active(self)
        else:
            self.screen.blit(self.game_over_surface, self.game_over_rect)

        # Floor
        self.floor_x_pos -= 4
        self.draw_floor(self)
        if self.floor_x_pos <= -576:
            self.floor_x_pos = 0

        pygame.display.update()
        self.clock.tick(60)

    @staticmethod
    def get_reward(self):
        """Return the reward for the frame that was just drawn.

        -100 when the bird died, +50 when it crossed a pipe gap, +0.1
        otherwise, each scaled by `rewardMultiplier`.  The previous version
        tested the bound methods `check_collision` / `check_pipe_reached`,
        which are always truthy, so it returned the +50 reward on every frame.
        """
        if not self.game_active:
            return -100 * self.rewardMultiplier
        if self.pipe_passed:
            return 50 * self.rewardMultiplier
        return 0.1 * self.rewardMultiplier

    @staticmethod
    def get_state(self):
        """Return the 9-feature observation as a list.

        Order: bird y, gap x, gap y, distance to gap, dy to gap, dx to gap,
        bird vertical speed, bird height above the floor, gap height above the
        floor.  With no gap collider on screen the gap features are 0 (and 50
        for the last one), in the same positions as in the other branch; the
        two branches previously disagreed on the order.
        """
        # Ask the collider list directly: pipe_is_spawned is cleared whenever a
        # pipe is passed, even if the next pipe is already on screen.
        if self.collider_list:
            gap = self.collider_list[0]
            dx = gap.centerx - self.bird_rect.centerx
            dy = gap.centery - self.bird_rect.centery
            position = [self.bird_rect.centery, gap.centerx, gap.centery,
                        math.sqrt(dx ** 2 + dy ** 2), dy, dx,
                        self.bird_movement, 900 - self.bird_rect.bottom,
                        900 - gap.centery]
        else:
            position = [self.bird_rect.centery, 0, 0, 0, 0, 0,
                        self.bird_movement, 900 - self.bird_rect.bottom, 50]
        return position

    @staticmethod
    def step(self, action):
        """Apply `action`, render one frame, and return (reward, next state)."""
        self.check_event(self, action)
        self.draw_frame(self)
        point = self.get_reward(self)
        self.is_done(self)
        position = self.get_state(self)
        return point, position
我建議先嘗試在您的模型架構中再添加一層,因為目前的網絡非常小。此外,您從該環境接收到的觀察值中含有一些數值很大的整數,因此您可能需要對每個特徵進行歸一化。另一個與此無關的建議是:使用 collections.deque(maxlen=...) 的隊列機制來實現經驗回放緩衝區會更有效率。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.