簡體   English   中英

我用 python 中的 pygame 制作的 FlappyBird 的 DQN 不要學習

[英]My DQN for a FlappyBird made with pygame in python Don't learn


我創建了一個 DQN,它學會玩用 pygame 制作的 FlappyBird,但我的問題是,當我運行代碼時,它會發現 DQN 不學習,損失沒有減少,而且鳥沒有達到通過第一個 pipe .
我已經嘗試改變一些超參數但沒有任何改變,我改變了學習率,我試圖隨着時間的推移降低學習率。 我嘗試增加:目標權重的更新,重放的大小 memory,批量大小。 我認為問題可能來自你的 flappyBird,但我不知道我能改變什么。

主文件

# Import Libraries import pygame import sys from FlappyBird import FlappyBird import math import random import numpy as np import matplotlib.pyplot as plt from collections import namedtuple from itertools import count import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F import torchvision.transforms as T # ---------------------------------------------------------------------------------------------------------------------- # Deep Q-Network class DQN(nn.Module): def __init__(self): super().__init__() self.fc1 = nn.Linear(in_features=9, out_features=20) self.out = nn.Linear(in_features=20, out_features=2) def forward(self, t): t = F.relu(self.fc1(t)) # t = F.relu(self.fc2(t)) t = self.out(t) return t # ---------------------------------------------------------------------------------------------------------------------- # Experience class Experience = namedtuple( 'Experience', ('state', 'action', 'next_state', 'reward') ) # ---------------------------------------------------------------------------------------------------------------------- # Replay Memory class ReplayMemory: def __init__(self, capacity): self.capacity = capacity self.memory = [] self.push_count = 0 def push(self, experience): if len(self.memory) < self.capacity: self.memory.append(experience) else: self.memory[self.push_count % self.capacity] = experience self.push_count += 1 def sample(self, batch_size): return random.sample(self.memory, batch_size) def can_provide_sample(self, batch_size): return len(self.memory) >= batch_size # ---------------------------------------------------------------------------------------------------------------------- # Epsilon Greedy Strategy class EpsilonGreedyStrategy: def __init__(self, start, end, decay): self.start = start self.end = end self.decay = decay def get_exploration_rate(self, current_step): return self.end + (self.start - self.end) * math.exp(-current_step * self.decay) # ---------------------------------------------------------------------------------------------------------------------- # Reinforcement Learning Agent class Agent: def __init__(self, strategy, num_actions, device): self.current_step = 0 self.strategy = strategy self.num_actions = num_actions self.device = device def select_action(self, state, policy_net): rate = strategy.get_exploration_rate(self.current_step) self.current_step += 1 if rate > random.random(): action = random.randrange(self.num_actions) return torch.tensor([action]).to(device) # explore else: with torch.no_grad(): return policy_net(state).argmax(dim=1).to(device) # exploit # ---------------------------------------------------------------------------------------------------------------------- # Environment Manager # FlappyBird.py # ---------------------------------------------------------------------------------------------------------------------- # Progression Graph def plot(values, moving_avg_period): plt.figure(2) plt.clf() plt.title('Training...') plt.xlabel('Episode') plt.ylabel('Score') plt.plot(values) moving_avg = get_moving_average(moving_avg_period, values) plt.plot(moving_avg) plt.pause(0.001) print("Episode", len(values), "\n", moving_avg_period, "episode moving avg:", moving_avg[-1]) def get_moving_average(period, values): values = torch.tensor(values, dtype=torch.float) if len(values) >= period: moving_avg = values.unfold(dimension=0, size=period, step=1).mean(dim=1).flatten(start_dim=0) moving_avg = torch.cat((torch.zeros(period - 1), moving_avg)) return moving_avg.numpy() else: moving_avg = torch.zeros(len(values)) return moving_avg.numpy() # ---------------------------------------------------------------------------------------------------------------------- # Tensor processing def extract_tensors(experiences): batch = Experience(*zip(*experiences)) t1 = torch.cat(batch.state) t2 = torch.cat(batch.action) t3 = torch.cat(batch.reward) t4 = torch.cat(batch.next_state) return t1, t2, t3, t4 # ---------------------------------------------------------------------------------------------------------------------- # Q-Value Calculator class QValues: device = torch.device("cpu") @staticmethod def get_current(policy_net, states, actions): return policy_net(states).gather(dim=1, index=actions.unsqueeze(-1)) @staticmethod def get_next(target_net, next_states): final_state_locations = next_states.flatten(start_dim=1).max(dim=1)[0].eq(0).type(torch.bool) non_final_state_locations = (final_state_locations == False) non_final_states = next_states[non_final_state_locations] batch_size = next_states.shape[0] values = torch.zeros(batch_size).to(QValues.device) values[non_final_state_locations] = target_net(non_final_states).max(dim=1)[0].detach() return values # ---------------------------------------------------------------------------------------------------------------------- # Main Program FIRST_TIME = True # True if this is the first time you run this code batch_size = 128 gamma = 0.999 eps_start = 1 eps_end = 0.01 eps_decay = 0.000001 target_update = 3000 memory_size = 500000 lr = 0.0005 num_episodes = 500000 weight_save = 100 lr_update = 500 device = torch.device("cpu") strategy = EpsilonGreedyStrategy(eps_start, eps_end, eps_decay) agent = Agent(strategy, 2, device) memory = ReplayMemory(memory_size) policy_net = DQN().to(device) target_net = DQN().to(device) target_net.eval() optimizer = optim.Adam(params=policy_net.parameters(), lr=lr) if FIRST_TIME: torch.save(optimizer.state_dict(), "./Optimizer_weight/weight.pt") optimizer.load_state_dict(torch.load("./Optimizer_weight/weight.pt")) score = [] if FIRST_TIME: target_net.load_state_dict(policy_net.state_dict()) torch.save(policy_net.state_dict(), "./Policy_weight/weight.pt") torch.save(target_net.state_dict(), "./Target_weight/weight.pt") policy_net.load_state_dict(torch.load("./Policy_weight/weight.pt")) target_net.load_state_dict(torch.load("./Target_weight/weight.pt")) pygame.mixer.pre_init(frequency=44100, size=16, channels=1, buffer=512) pygame.init() em = FlappyBird() for episode in range(num_episodes): em.start_game(em) state = torch.tensor([em.get_state(em)]).to(device).float() while True: action = agent.select_action(state, policy_net) reward, next_state = em.step(em, action) reward = torch.tensor([reward]).to(device).float() next_state = torch.tensor([next_state]).to(device).float() memory.push(Experience(state, action, next_state, reward)) state = next_state if memory.can_provide_sample(batch_size): experiences = memory.sample(batch_size) states, actions, rewards, next_states = extract_tensors(experiences) current_q_values = QValues.get_current(policy_net, states, actions) next_q_values = QValues.get_next(target_net, next_states) target_q_values = (next_q_values * gamma) + rewards loss = F.mse_loss(current_q_values, target_q_values.unsqueeze(1)) optimizer.zero_grad() loss.backward() optimizer.step() if em.is_done(em): score.append(em.get_score(em)) plot(score, 100) print("Exploration: ", round(strategy.get_exploration_rate(agent.current_step) * 100), "%") if memory.can_provide_sample(batch_size): print("Loss: ", loss.data) break if episode % target_update == 0: target_net.load_state_dict(policy_net.state_dict()) torch.save(target_net.state_dict(), "./Target_weight/weight.pt") print("Target Updated") if episode % weight_save == 0: torch.save(policy_net.state_dict(), "./Policy_weight/weight.pt") torch.save(optimizer.state_dict(), "./Optimizer_weight/weight.pt") if episode % lr_update == 0: lr = lr * 0.995 print("Learning Rate: ", lr) pygame.quit() sys.exit()

FlappyBird.py

 # Imports import pygame import sys import random import math # ---------------------------------------------------------------------------------------------------------------------- # Game Class class FlappyBird: def __init__(self): self.pipe_is_spawned = False self.screen = pygame.display.set_mode((576, 1024)) self.clock = pygame.time.Clock() self.game_font = pygame.font.Font('Flappy bird assets/04B_19.TTF', 60) # Games Variables self.gravity = 0.6 self.bird_movement = 0 self.game_active = False self.score = 0 self.high_score = 0 self.speed = 100 self.bg_surface = pygame.image.load('Flappy bird assets/sprites/background-day.png').convert() self.bg_surface = pygame.transform.scale2x(self.bg_surface) self.floor_surface = pygame.image.load('Flappy bird assets/sprites/base.png').convert() self.floor_surface = pygame.transform.scale2x(self.floor_surface) self.floor_x_pos = 0 self.bird_downflap = pygame.transform.scale2x( pygame.image.load('Flappy bird assets/sprites/bluebird-downflap.png')).convert_alpha() self.bird_midflap = pygame.transform.scale2x( pygame.image.load('Flappy bird assets/sprites/bluebird-midflap.png')).convert_alpha() self.bird_upflap = pygame.transform.scale2x( pygame.image.load('Flappy bird assets/sprites/bluebird-upflap.png')).convert_alpha() self.bird_frames = [self.bird_downflap, self.bird_midflap, self.bird_upflap] self.bird_index = 0 self.bird_surface = self.bird_frames[self.bird_index] self.bird_rect = self.bird_surface.get_rect(center=(100, 512)) self.BIRDFLAP = pygame.USEREVENT + 1 pygame.time.set_timer(self.BIRDFLAP, 800) self.SPEEDUP = pygame.USEREVENT + 2 pygame.time.set_timer(self.SPEEDUP, 40000) self.pipe_surface = pygame.image.load('Flappy bird assets/sprites/pipe-green.png').convert() self.pipe_surface = pygame.transform.scale2x(self.pipe_surface) self.pipe_list = [] self.SPAWNPIPE = pygame.USEREVENT pygame.time.set_timer(self.SPAWNPIPE, 2400) self.pipe_height = [525, 550, 600, 650, 675] self.point_collider = pygame.Surface((10, 300)) self.point_collider.set_alpha(0) self.collider_list = [] self.game_over_surface = pygame.image.load('Flappy bird assets/sprites/message.png').convert_alpha() self.game_over_surface = pygame.transform.scale2x(self.game_over_surface) self.game_over_rect = self.game_over_surface.get_rect(center=(288, 512)) self.flap_sound = pygame.mixer.Sound('Flappy bird assets/audio/wing.wav') self.death_sound = pygame.mixer.Sound('Flappy bird assets/audio/die.wav') self.hit_sound = pygame.mixer.Sound('Flappy bird assets/audio/hit.wav') self.point_sound = pygame.mixer.Sound('Flappy bird assets/audio/point.wav') self.rewardMultiplier = 1 self.INCREASEREWARD = pygame.USEREVENT pygame.time.set_timer(self.INCREASEREWARD, 2400) @staticmethod def draw_floor(self): self.screen.blit(self.floor_surface, (self.floor_x_pos, 900)) self.screen.blit(self.floor_surface, (self.floor_x_pos + 576, 900)) @staticmethod def create_pipe(self): random_pipe_pos = random.choice(self.pipe_height) bottomPipe = self.pipe_surface.get_rect(midtop=(700, random_pipe_pos)) topPipe = self.pipe_surface.get_rect(midbottom=(700, random_pipe_pos - 300)) point_collider_rect = self.point_collider.get_rect(midbottom=(700, random_pipe_pos)) return bottomPipe, topPipe, point_collider_rect @staticmethod def move_pipes(self, pipes, colliders): for pipe in pipes: pipe.centerx -= self.speed for collider in colliders: collider.centerx -= self.speed return pipes, colliders @staticmethod def draw_pipes(self, pipes, colliders): for pipe in pipes: if pipe.bottom >= 1024: self.screen.blit(self.pipe_surface, pipe) else: flip_pipe = pygame.transform.flip(self.pipe_surface, False, True) self.screen.blit(flip_pipe, pipe) for collider in colliders: self.screen.blit(self.point_collider, collider) @staticmethod def check_collision(self, pipes): for pipe in pipes: if self.bird_rect.colliderect(pipe): self.hit_sound.play() return False if self.bird_rect.top <= -100 or self.bird_rect.bottom >= 900: self.death_sound.play() return False return True @staticmethod def check_pipe_reached(self, colliders): for collider in colliders: if self.bird_rect.colliderect(collider): colliders.remove(collider) self.pipe_is_spawned = False return colliders, True return colliders, False @staticmethod def rotate_bird(self, bird): new_bird = pygame.transform.rotozoom(bird, -self.bird_movement * 2, 1) return new_bird @staticmethod def bird_animation(self): new_bird = self.bird_frames[self.bird_index] new_bird_rect = new_bird.get_rect(center=(100, self.bird_rect.centery)) return new_bird, new_bird_rect @staticmethod def score_display(self): score_surface = self.game_font.render(str(self.score), True, (255, 255, 255)) score_rect = score_surface.get_rect(center=(288, 100)) self.screen.blit(score_surface, score_rect) @staticmethod def is_done(self): if not self.game_active: self.pipe_is_spawned = False return not self.game_active @staticmethod def get_score(self): return self.score @staticmethod def start_game(self): self.game_active = True self.pipe_list.clear() self.collider_list.clear() self.bird_rect.center = (100, 512) self.bird_movement = 0 self.score = 0 self.speed = 4 self.pipe_is_spawned = False @staticmethod def check_event(self, action): for event in pygame.event.get(): if event.type == pygame.QUIT: pygame.quit() sys.exit() if action == 1: self.bird_movement = 0 self.bird_movement -= 15 self.flap_sound.play() if event.type == self.SPAWNPIPE and self.game_active: bottom_pipe, top_pipe, win_collider = self.create_pipe(self) self.pipe_list.append(bottom_pipe) self.pipe_list.append(top_pipe) self.collider_list.append(win_collider) self.pipe_is_spawned = True if event.type == self.BIRDFLAP: if self.bird_index < 2: self.bird_index += 1 else: self.bird_index = 0 self.bird_surface, self.bird_rect = self.bird_animation(self) if event.type == self.SPEEDUP and self.game_active: if self.speed < 8: self.speed += 0.05 if event.type == self.INCREASEREWARD and self.game_active: self.rewardMultiplier += 0.01 @staticmethod def is_game_active(self): # Bird self.bird_movement += self.gravity rotated_bird = self.rotate_bird(self, self.bird_surface) self.bird_rect.centery += self.bird_movement self.screen.blit(rotated_bird, self.bird_rect) self.game_active = self.check_collision(self, self.pipe_list) collider_list, pipe_reached = self.check_pipe_reached(self, self.collider_list) if pipe_reached: self.point_sound.play() self.score += 1 # Pipes pipe_list, collider_list = self.move_pipes(self, self.pipe_list, self.collider_list) self.draw_pipes(self, pipe_list, collider_list) self.score_display(self) @staticmethod def draw_frame(self): self.screen.blit(self.bg_surface, (0, 0)) if self.game_active: self.is_game_active(self) else: self.screen.blit(self.game_over_surface, self.game_over_rect) # Floor self.floor_x_pos -= 4 self.draw_floor(self) if self.floor_x_pos <= -576: self.floor_x_pos = 0 pygame.display.update() self.clock.tick(60) @staticmethod def get_reward(self): point = 0.1 * self.rewardMultiplier if not self.check_collision: point = -100 * self.rewardMultiplier elif self.check_pipe_reached: point = 50 * self.rewardMultiplier return point @staticmethod def get_state(self): if self.pipe_is_spawned: position = [self.bird_rect.centery, self.collider_list[0].centerx, self.collider_list[0].centery, math.sqrt((self.collider_list[0].centerx - self.bird_rect.centerx) ** 2 + (self.collider_list[0].centery - self.bird_rect.centery) ** 2), self.collider_list[0].centery - self.bird_rect.centery, self.collider_list[0].centerx - self.bird_rect.centerx, self.bird_movement, 900 - self.bird_rect.bottom, 900 - self.collider_list[0].centery] else: position = [self.bird_rect.centery, 0, 0, 0, 0, 0, 900 - self.bird_rect.bottom, 50, self.bird_movement] return position @staticmethod def step(self, action): self.check_event(self, action) self.draw_frame(self) point = self.get_reward(self) self.is_done(self) position = self.get_state(self) return point, position

我會說先嘗試在您的 model 架構中添加另一層,因為它非常小。 此外,您從該環境中接收到的觀察值具有一些非常大的整數,因此您可能希望對這些特征中的每一個進行歸一化。 另一個不相關的建議是使用collections.deque(maxlen)中的隊列機制來實現重播緩沖區更有效。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM