[英]My DQN for a FlappyBird made with pygame in python Don't learn
main.py主文件
# Import Libraries import pygame import sys from FlappyBird import FlappyBird import math import random import numpy as np import matplotlib.pyplot as plt from collections import namedtuple from itertools import count import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F import torchvision.transforms as T # ---------------------------------------------------------------------------------------------------------------------- # Deep Q-Network class DQN(nn.Module): def __init__(self): super().__init__() self.fc1 = nn.Linear(in_features=9, out_features=20) self.out = nn.Linear(in_features=20, out_features=2) def forward(self, t): t = F.relu(self.fc1(t)) # t = F.relu(self.fc2(t)) t = self.out(t) return t # ---------------------------------------------------------------------------------------------------------------------- # Experience class Experience = namedtuple( 'Experience', ('state', 'action', 'next_state', 'reward') ) # ---------------------------------------------------------------------------------------------------------------------- # Replay Memory class ReplayMemory: def __init__(self, capacity): self.capacity = capacity self.memory = [] self.push_count = 0 def push(self, experience): if len(self.memory) < self.capacity: self.memory.append(experience) else: self.memory[self.push_count % self.capacity] = experience self.push_count += 1 def sample(self, batch_size): return random.sample(self.memory, batch_size) def can_provide_sample(self, batch_size): return len(self.memory) >= batch_size # ---------------------------------------------------------------------------------------------------------------------- # Epsilon Greedy Strategy class EpsilonGreedyStrategy: def __init__(self, start, end, decay): self.start = start self.end = end self.decay = decay def get_exploration_rate(self, current_step): return self.end + (self.start - self.end) * math.exp(-current_step * self.decay) # ---------------------------------------------------------------------------------------------------------------------- # Reinforcement Learning Agent class Agent: def __init__(self, strategy, num_actions, device): self.current_step = 0 self.strategy = strategy self.num_actions = num_actions self.device = device def select_action(self, state, policy_net): rate = strategy.get_exploration_rate(self.current_step) self.current_step += 1 if rate > random.random(): action = random.randrange(self.num_actions) return torch.tensor([action]).to(device) # explore else: with torch.no_grad(): return policy_net(state).argmax(dim=1).to(device) # exploit # ---------------------------------------------------------------------------------------------------------------------- # Environment Manager # FlappyBird.py # ---------------------------------------------------------------------------------------------------------------------- # Progression Graph def plot(values, moving_avg_period): plt.figure(2) plt.clf() plt.title('Training...') plt.xlabel('Episode') plt.ylabel('Score') plt.plot(values) moving_avg = get_moving_average(moving_avg_period, values) plt.plot(moving_avg) plt.pause(0.001) print("Episode", len(values), "\n", moving_avg_period, "episode moving avg:", moving_avg[-1]) def get_moving_average(period, values): values = torch.tensor(values, dtype=torch.float) if len(values) >= period: moving_avg = values.unfold(dimension=0, size=period, step=1).mean(dim=1).flatten(start_dim=0) moving_avg = torch.cat((torch.zeros(period - 1), moving_avg)) return moving_avg.numpy() else: moving_avg = torch.zeros(len(values)) return moving_avg.numpy() # ---------------------------------------------------------------------------------------------------------------------- # Tensor processing def extract_tensors(experiences): batch = Experience(*zip(*experiences)) t1 = torch.cat(batch.state) t2 = torch.cat(batch.action) t3 = torch.cat(batch.reward) t4 = torch.cat(batch.next_state) return t1, t2, t3, t4 # ---------------------------------------------------------------------------------------------------------------------- # Q-Value Calculator class QValues: device = torch.device("cpu") @staticmethod def get_current(policy_net, states, actions): return policy_net(states).gather(dim=1, index=actions.unsqueeze(-1)) @staticmethod def get_next(target_net, next_states): final_state_locations = next_states.flatten(start_dim=1).max(dim=1)[0].eq(0).type(torch.bool) non_final_state_locations = (final_state_locations == False) non_final_states = next_states[non_final_state_locations] batch_size = next_states.shape[0] values = torch.zeros(batch_size).to(QValues.device) values[non_final_state_locations] = target_net(non_final_states).max(dim=1)[0].detach() return values # ---------------------------------------------------------------------------------------------------------------------- # Main Program FIRST_TIME = True # True if this is the first time you run this code batch_size = 128 gamma = 0.999 eps_start = 1 eps_end = 0.01 eps_decay = 0.000001 target_update = 3000 memory_size = 500000 lr = 0.0005 num_episodes = 500000 weight_save = 100 lr_update = 500 device = torch.device("cpu") strategy = EpsilonGreedyStrategy(eps_start, eps_end, eps_decay) agent = Agent(strategy, 2, device) memory = ReplayMemory(memory_size) policy_net = DQN().to(device) target_net = DQN().to(device) target_net.eval() optimizer = optim.Adam(params=policy_net.parameters(), lr=lr) if FIRST_TIME: torch.save(optimizer.state_dict(), "./Optimizer_weight/weight.pt") optimizer.load_state_dict(torch.load("./Optimizer_weight/weight.pt")) score = [] if FIRST_TIME: target_net.load_state_dict(policy_net.state_dict()) torch.save(policy_net.state_dict(), "./Policy_weight/weight.pt") torch.save(target_net.state_dict(), "./Target_weight/weight.pt") policy_net.load_state_dict(torch.load("./Policy_weight/weight.pt")) target_net.load_state_dict(torch.load("./Target_weight/weight.pt")) pygame.mixer.pre_init(frequency=44100, size=16, channels=1, buffer=512) pygame.init() em = FlappyBird() for episode in range(num_episodes): em.start_game(em) state = torch.tensor([em.get_state(em)]).to(device).float() while True: action = agent.select_action(state, policy_net) reward, next_state = em.step(em, action) reward = torch.tensor([reward]).to(device).float() next_state = torch.tensor([next_state]).to(device).float() memory.push(Experience(state, action, next_state, reward)) state = next_state if memory.can_provide_sample(batch_size): experiences = memory.sample(batch_size) states, actions, rewards, next_states = extract_tensors(experiences) current_q_values = QValues.get_current(policy_net, states, actions) next_q_values = QValues.get_next(target_net, next_states) target_q_values = (next_q_values * gamma) + rewards loss = F.mse_loss(current_q_values, target_q_values.unsqueeze(1)) optimizer.zero_grad() loss.backward() optimizer.step() if em.is_done(em): score.append(em.get_score(em)) plot(score, 100) print("Exploration: ", round(strategy.get_exploration_rate(agent.current_step) * 100), "%") if memory.can_provide_sample(batch_size): print("Loss: ", loss.data) break if episode % target_update == 0: target_net.load_state_dict(policy_net.state_dict()) torch.save(target_net.state_dict(), "./Target_weight/weight.pt") print("Target Updated") if episode % weight_save == 0: torch.save(policy_net.state_dict(), "./Policy_weight/weight.pt") torch.save(optimizer.state_dict(), "./Optimizer_weight/weight.pt") if episode % lr_update == 0: lr = lr * 0.995 print("Learning Rate: ", lr) pygame.quit() sys.exit()
FlappyBird.py FlappyBird.py
# Imports import pygame import sys import random import math # ---------------------------------------------------------------------------------------------------------------------- # Game Class class FlappyBird: def __init__(self): self.pipe_is_spawned = False self.screen = pygame.display.set_mode((576, 1024)) self.clock = pygame.time.Clock() self.game_font = pygame.font.Font('Flappy bird assets/04B_19.TTF', 60) # Games Variables self.gravity = 0.6 self.bird_movement = 0 self.game_active = False self.score = 0 self.high_score = 0 self.speed = 100 self.bg_surface = pygame.image.load('Flappy bird assets/sprites/background-day.png').convert() self.bg_surface = pygame.transform.scale2x(self.bg_surface) self.floor_surface = pygame.image.load('Flappy bird assets/sprites/base.png').convert() self.floor_surface = pygame.transform.scale2x(self.floor_surface) self.floor_x_pos = 0 self.bird_downflap = pygame.transform.scale2x( pygame.image.load('Flappy bird assets/sprites/bluebird-downflap.png')).convert_alpha() self.bird_midflap = pygame.transform.scale2x( pygame.image.load('Flappy bird assets/sprites/bluebird-midflap.png')).convert_alpha() self.bird_upflap = pygame.transform.scale2x( pygame.image.load('Flappy bird assets/sprites/bluebird-upflap.png')).convert_alpha() self.bird_frames = [self.bird_downflap, self.bird_midflap, self.bird_upflap] self.bird_index = 0 self.bird_surface = self.bird_frames[self.bird_index] self.bird_rect = self.bird_surface.get_rect(center=(100, 512)) self.BIRDFLAP = pygame.USEREVENT + 1 pygame.time.set_timer(self.BIRDFLAP, 800) self.SPEEDUP = pygame.USEREVENT + 2 pygame.time.set_timer(self.SPEEDUP, 40000) self.pipe_surface = pygame.image.load('Flappy bird assets/sprites/pipe-green.png').convert() self.pipe_surface = pygame.transform.scale2x(self.pipe_surface) self.pipe_list = [] self.SPAWNPIPE = pygame.USEREVENT pygame.time.set_timer(self.SPAWNPIPE, 2400) self.pipe_height = [525, 550, 600, 650, 675] self.point_collider = pygame.Surface((10, 300)) self.point_collider.set_alpha(0) self.collider_list = [] self.game_over_surface = pygame.image.load('Flappy bird assets/sprites/message.png').convert_alpha() self.game_over_surface = pygame.transform.scale2x(self.game_over_surface) self.game_over_rect = self.game_over_surface.get_rect(center=(288, 512)) self.flap_sound = pygame.mixer.Sound('Flappy bird assets/audio/wing.wav') self.death_sound = pygame.mixer.Sound('Flappy bird assets/audio/die.wav') self.hit_sound = pygame.mixer.Sound('Flappy bird assets/audio/hit.wav') self.point_sound = pygame.mixer.Sound('Flappy bird assets/audio/point.wav') self.rewardMultiplier = 1 self.INCREASEREWARD = pygame.USEREVENT pygame.time.set_timer(self.INCREASEREWARD, 2400) @staticmethod def draw_floor(self): self.screen.blit(self.floor_surface, (self.floor_x_pos, 900)) self.screen.blit(self.floor_surface, (self.floor_x_pos + 576, 900)) @staticmethod def create_pipe(self): random_pipe_pos = random.choice(self.pipe_height) bottomPipe = self.pipe_surface.get_rect(midtop=(700, random_pipe_pos)) topPipe = self.pipe_surface.get_rect(midbottom=(700, random_pipe_pos - 300)) point_collider_rect = self.point_collider.get_rect(midbottom=(700, random_pipe_pos)) return bottomPipe, topPipe, point_collider_rect @staticmethod def move_pipes(self, pipes, colliders): for pipe in pipes: pipe.centerx -= self.speed for collider in colliders: collider.centerx -= self.speed return pipes, colliders @staticmethod def draw_pipes(self, pipes, colliders): for pipe in pipes: if pipe.bottom >= 1024: self.screen.blit(self.pipe_surface, pipe) else: flip_pipe = pygame.transform.flip(self.pipe_surface, False, True) self.screen.blit(flip_pipe, pipe) for collider in colliders: self.screen.blit(self.point_collider, collider) @staticmethod def check_collision(self, pipes): for pipe in pipes: if self.bird_rect.colliderect(pipe): self.hit_sound.play() return False if self.bird_rect.top <= -100 or self.bird_rect.bottom >= 900: self.death_sound.play() return False return True @staticmethod def check_pipe_reached(self, colliders): for collider in colliders: if self.bird_rect.colliderect(collider): colliders.remove(collider) self.pipe_is_spawned = False return colliders, True return colliders, False @staticmethod def rotate_bird(self, bird): new_bird = pygame.transform.rotozoom(bird, -self.bird_movement * 2, 1) return new_bird @staticmethod def bird_animation(self): new_bird = self.bird_frames[self.bird_index] new_bird_rect = new_bird.get_rect(center=(100, self.bird_rect.centery)) return new_bird, new_bird_rect @staticmethod def score_display(self): score_surface = self.game_font.render(str(self.score), True, (255, 255, 255)) score_rect = score_surface.get_rect(center=(288, 100)) self.screen.blit(score_surface, score_rect) @staticmethod def is_done(self): if not self.game_active: self.pipe_is_spawned = False return not self.game_active @staticmethod def get_score(self): return self.score @staticmethod def start_game(self): self.game_active = True self.pipe_list.clear() self.collider_list.clear() self.bird_rect.center = (100, 512) self.bird_movement = 0 self.score = 0 self.speed = 4 self.pipe_is_spawned = False @staticmethod def check_event(self, action): for event in pygame.event.get(): if event.type == pygame.QUIT: pygame.quit() sys.exit() if action == 1: self.bird_movement = 0 self.bird_movement -= 15 self.flap_sound.play() if event.type == self.SPAWNPIPE and self.game_active: bottom_pipe, top_pipe, win_collider = self.create_pipe(self) self.pipe_list.append(bottom_pipe) self.pipe_list.append(top_pipe) self.collider_list.append(win_collider) self.pipe_is_spawned = True if event.type == self.BIRDFLAP: if self.bird_index < 2: self.bird_index += 1 else: self.bird_index = 0 self.bird_surface, self.bird_rect = self.bird_animation(self) if event.type == self.SPEEDUP and self.game_active: if self.speed < 8: self.speed += 0.05 if event.type == self.INCREASEREWARD and self.game_active: self.rewardMultiplier += 0.01 @staticmethod def is_game_active(self): # Bird self.bird_movement += self.gravity rotated_bird = self.rotate_bird(self, self.bird_surface) self.bird_rect.centery += self.bird_movement self.screen.blit(rotated_bird, self.bird_rect) self.game_active = self.check_collision(self, self.pipe_list) collider_list, pipe_reached = self.check_pipe_reached(self, self.collider_list) if pipe_reached: self.point_sound.play() self.score += 1 # Pipes pipe_list, collider_list = self.move_pipes(self, self.pipe_list, self.collider_list) self.draw_pipes(self, pipe_list, collider_list) self.score_display(self) @staticmethod def draw_frame(self): self.screen.blit(self.bg_surface, (0, 0)) if self.game_active: self.is_game_active(self) else: self.screen.blit(self.game_over_surface, self.game_over_rect) # Floor self.floor_x_pos -= 4 self.draw_floor(self) if self.floor_x_pos <= -576: self.floor_x_pos = 0 pygame.display.update() self.clock.tick(60) @staticmethod def get_reward(self): point = 0.1 * self.rewardMultiplier if not self.check_collision: point = -100 * self.rewardMultiplier elif self.check_pipe_reached: point = 50 * self.rewardMultiplier return point @staticmethod def get_state(self): if self.pipe_is_spawned: position = [self.bird_rect.centery, self.collider_list[0].centerx, self.collider_list[0].centery, math.sqrt((self.collider_list[0].centerx - self.bird_rect.centerx) ** 2 + (self.collider_list[0].centery - self.bird_rect.centery) ** 2), self.collider_list[0].centery - self.bird_rect.centery, self.collider_list[0].centerx - self.bird_rect.centerx, self.bird_movement, 900 - self.bird_rect.bottom, 900 - self.collider_list[0].centery] else: position = [self.bird_rect.centery, 0, 0, 0, 0, 0, 900 - self.bird_rect.bottom, 50, self.bird_movement] return position @staticmethod def step(self, action): self.check_event(self, action) self.draw_frame(self) point = self.get_reward(self) self.is_done(self) position = self.get_state(self) return point, position
I'd say try adding another layer to your model architecture first, because it's very small.我会说先尝试在您的 model 架构中添加另一层,因为它非常小。 Also, the observations your receiving from that environment have some very large integers, so you might want to normalize each of those features.
此外,您从该环境中接收到的观察值具有一些非常大的整数,因此您可能希望对这些特征中的每一个进行归一化。 Another suggestion that is unrelated is that it is more efficient to implement the replay buffer using a queue mechanism from
collections.deque(maxlen)
.另一个不相关的建议是使用
collections.deque(maxlen)
中的队列机制来实现重播缓冲区更有效。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.