

My DQN for a FlappyBird made with pygame in Python doesn't learn


I created a DQN that learns to play a FlappyBird game made with pygame. My problem is that when I run the code, the DQN doesn't seem to learn: the loss does not decrease, and the bird never manages to pass the first pipe.
I have already tried changing some of the hyperparameters, but nothing changed. I changed the learning rate, and I tried decreasing it over time. I also tried increasing the target-network update interval, the size of the replay memory, and the batch size. I think the problem may come from the FlappyBird code, but I don't see what I could change.
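Decreasing the learning rate over time only has an effect if the new value actually reaches the optimizer. In the main.py listing below, `lr = lr * 0.995` only rebinds a Python float after Adam has been created, so Adam keeps using 0.0005. Here is a minimal sketch that applies the same 0.995 factor every 500 episodes through PyTorch's built-in `torch.optim.lr_scheduler.StepLR`. The `nn.Linear` model and the dummy loss are stand-ins for `policy_net` and the DQN loss. They are not the question's code.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Stand-in for policy_net: 9 state features -> 2 actions.
model = nn.Linear(9, 2)
optimizer = optim.Adam(model.parameters(), lr=0.0005)

# Multiply the learning rate by 0.995 every 500 scheduler steps.
# Calling scheduler.step() once per episode gives "every 500 episodes",
# mirroring lr_update = 500 and the 0.995 factor in main.py.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.995)

for episode in range(1500):
    # Stand-in for one episode of training: a single dummy gradient step.
    loss = model(torch.randn(1, 9)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # once per episode, after the optimizer has stepped

# The decayed value lives in the optimizer's param_groups, which Adam reads.
current_lr = optimizer.param_groups[0]["lr"]
print(current_lr)
```

Assigning `optimizer.param_groups[0]["lr"] = lr` directly inside the existing `lr_update` branch would also work. The scheduler just keeps that bookkeeping in one place.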

main.py

    # Import Libraries
    import pygame
    import sys
    from FlappyBird import FlappyBird
    import math
    import random
    import numpy as np
    import matplotlib.pyplot as plt
    from collections import namedtuple
    from itertools import count
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
    import torchvision.transforms as T

    # ----------------------------------------------------------------------------------------------------------------------
    # Deep Q-Network
    class DQN(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(in_features=9, out_features=20)
            self.out = nn.Linear(in_features=20, out_features=2)

        def forward(self, t):
            t = F.relu(self.fc1(t))
            # t = F.relu(self.fc2(t))
            t = self.out(t)
            return t

    # ----------------------------------------------------------------------------------------------------------------------
    # Experience class
    Experience = namedtuple(
        'Experience',
        ('state', 'action', 'next_state', 'reward')
    )

    # ----------------------------------------------------------------------------------------------------------------------
    # Replay Memory
    class ReplayMemory:
        def __init__(self, capacity):
            self.capacity = capacity
            self.memory = []
            self.push_count = 0

        def push(self, experience):
            if len(self.memory) < self.capacity:
                self.memory.append(experience)
            else:
                self.memory[self.push_count % self.capacity] = experience
            self.push_count += 1

        def sample(self, batch_size):
            return random.sample(self.memory, batch_size)

        def can_provide_sample(self, batch_size):
            return len(self.memory) >= batch_size

    # ----------------------------------------------------------------------------------------------------------------------
    # Epsilon Greedy Strategy
    class EpsilonGreedyStrategy:
        def __init__(self, start, end, decay):
            self.start = start
            self.end = end
            self.decay = decay

        def get_exploration_rate(self, current_step):
            return self.end + (self.start - self.end) * math.exp(-current_step * self.decay)

    # ----------------------------------------------------------------------------------------------------------------------
    # Reinforcement Learning Agent
    class Agent:
        def __init__(self, strategy, num_actions, device):
            self.current_step = 0
            self.strategy = strategy
            self.num_actions = num_actions
            self.device = device

        def select_action(self, state, policy_net):
            rate = strategy.get_exploration_rate(self.current_step)
            self.current_step += 1

            if rate > random.random():
                action = random.randrange(self.num_actions)
                return torch.tensor([action]).to(device)  # explore
            else:
                with torch.no_grad():
                    return policy_net(state).argmax(dim=1).to(device)  # exploit

    # ----------------------------------------------------------------------------------------------------------------------
    # Environment Manager
    # FlappyBird.py

    # ----------------------------------------------------------------------------------------------------------------------
    # Progression Graph
    def plot(values, moving_avg_period):
        plt.figure(2)
        plt.clf()
        plt.title('Training...')
        plt.xlabel('Episode')
        plt.ylabel('Score')
        plt.plot(values)

        moving_avg = get_moving_average(moving_avg_period, values)
        plt.plot(moving_avg)
        plt.pause(0.001)
        print("Episode", len(values), "\n", moving_avg_period, "episode moving avg:", moving_avg[-1])


    def get_moving_average(period, values):
        values = torch.tensor(values, dtype=torch.float)
        if len(values) >= period:
            moving_avg = values.unfold(dimension=0, size=period, step=1).mean(dim=1).flatten(start_dim=0)
            moving_avg = torch.cat((torch.zeros(period - 1), moving_avg))
            return moving_avg.numpy()
        else:
            moving_avg = torch.zeros(len(values))
            return moving_avg.numpy()

    # ----------------------------------------------------------------------------------------------------------------------
    # Tensor processing
    def extract_tensors(experiences):
        batch = Experience(*zip(*experiences))

        t1 = torch.cat(batch.state)
        t2 = torch.cat(batch.action)
        t3 = torch.cat(batch.reward)
        t4 = torch.cat(batch.next_state)

        return t1, t2, t3, t4

    # ----------------------------------------------------------------------------------------------------------------------
    # Q-Value Calculator
    class QValues:
        device = torch.device("cpu")

        @staticmethod
        def get_current(policy_net, states, actions):
            return policy_net(states).gather(dim=1, index=actions.unsqueeze(-1))

        @staticmethod
        def get_next(target_net, next_states):
            final_state_locations = next_states.flatten(start_dim=1).max(dim=1)[0].eq(0).type(torch.bool)
            non_final_state_locations = (final_state_locations == False)
            non_final_states = next_states[non_final_state_locations]
            batch_size = next_states.shape[0]
            values = torch.zeros(batch_size).to(QValues.device)
            values[non_final_state_locations] = target_net(non_final_states).max(dim=1)[0].detach()
            return values

    # ----------------------------------------------------------------------------------------------------------------------
    # Main Program
    FIRST_TIME = True  # True if this is the first time you run this code

    batch_size = 128
    gamma = 0.999
    eps_start = 1
    eps_end = 0.01
    eps_decay = 0.000001
    target_update = 3000
    memory_size = 500000
    lr = 0.0005
    num_episodes = 500000
    weight_save = 100
    lr_update = 500

    device = torch.device("cpu")
    strategy = EpsilonGreedyStrategy(eps_start, eps_end, eps_decay)
    agent = Agent(strategy, 2, device)
    memory = ReplayMemory(memory_size)

    policy_net = DQN().to(device)
    target_net = DQN().to(device)
    target_net.eval()

    optimizer = optim.Adam(params=policy_net.parameters(), lr=lr)
    if FIRST_TIME:
        torch.save(optimizer.state_dict(), "./Optimizer_weight/weight.pt")
    optimizer.load_state_dict(torch.load("./Optimizer_weight/weight.pt"))

    score = []

    if FIRST_TIME:
        target_net.load_state_dict(policy_net.state_dict())
        torch.save(policy_net.state_dict(), "./Policy_weight/weight.pt")
        torch.save(target_net.state_dict(), "./Target_weight/weight.pt")

    policy_net.load_state_dict(torch.load("./Policy_weight/weight.pt"))
    target_net.load_state_dict(torch.load("./Target_weight/weight.pt"))

    pygame.mixer.pre_init(frequency=44100, size=16, channels=1, buffer=512)
    pygame.init()
    em = FlappyBird()

    for episode in range(num_episodes):
        em.start_game(em)
        state = torch.tensor([em.get_state(em)]).to(device).float()

        while True:
            action = agent.select_action(state, policy_net)
            reward, next_state = em.step(em, action)
            reward = torch.tensor([reward]).to(device).float()
            next_state = torch.tensor([next_state]).to(device).float()
            memory.push(Experience(state, action, next_state, reward))
            state = next_state

            if memory.can_provide_sample(batch_size):
                experiences = memory.sample(batch_size)
                states, actions, rewards, next_states = extract_tensors(experiences)

                current_q_values = QValues.get_current(policy_net, states, actions)
                next_q_values = QValues.get_next(target_net, next_states)
                target_q_values = (next_q_values * gamma) + rewards

                loss = F.mse_loss(current_q_values, target_q_values.unsqueeze(1))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if em.is_done(em):
                score.append(em.get_score(em))
                plot(score, 100)
                print("Exploration: ", round(strategy.get_exploration_rate(agent.current_step) * 100), "%")
                if memory.can_provide_sample(batch_size):
                    print("Loss: ", loss.data)
                break

        if episode % target_update == 0:
            target_net.load_state_dict(policy_net.state_dict())
            torch.save(target_net.state_dict(), "./Target_weight/weight.pt")
            print("Target Updated")

        if episode % weight_save == 0:
            torch.save(policy_net.state_dict(), "./Policy_weight/weight.pt")
            torch.save(optimizer.state_dict(), "./Optimizer_weight/weight.pt")

        if episode % lr_update == 0:
            lr = lr * 0.995
            print("Learning Rate: ", lr)

    pygame.quit()
    sys.exit()
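`QValues.get_next` treats a next state as final only when its largest feature equals 0 (`.max(dim=1)[0].eq(0)`). `get_state` always returns `bird_rect.centery`, and it also returns the constant 50 when no pipe is on screen. In practice that test therefore never fires, and the target keeps adding `gamma * max Q(s')` even for the transition in which the bird crashed.

A common alternative is to record an explicit done flag with each transition, for example `em.is_done(em)` right after `em.step(...)`, and multiply the bootstrap term by `1 - done`. Below is a minimal sketch. `td_targets`, `dqn_loss` and the `dones` tensor are names introduced here for illustration; they are not part of the question's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def td_targets(target_net, rewards, next_states, dones, gamma):
    """r + gamma * max_a' Q_target(s', a'), with the bootstrap term dropped
    for transitions whose `dones` entry is 1.0 (the episode ended there)."""
    with torch.no_grad():
        next_q = target_net(next_states).max(dim=1).values
    return rewards + gamma * next_q * (1.0 - dones)


def dqn_loss(policy_net, target_net, states, actions, rewards, next_states, dones, gamma):
    # Q(s, a) for the actions actually taken, shape (batch,)
    current_q = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    targets = td_targets(target_net, rewards, next_states, dones, gamma)
    return F.mse_loss(current_q, targets)


# Usage with random data (9 features, 2 actions, as in the question).
torch.manual_seed(0)
policy_net, target_net = nn.Linear(9, 2), nn.Linear(9, 2)
states = torch.randn(4, 9)
next_states = torch.randn(4, 9)
actions = torch.tensor([0, 1, 1, 0])
rewards = torch.tensor([0.1, 0.1, -100.0, 0.1])
dones = torch.tensor([0.0, 0.0, 1.0, 0.0])  # the third transition was a crash

targets = td_targets(target_net, rewards, next_states, dones, gamma=0.99)
loss = dqn_loss(policy_net, target_net, states, actions, rewards, next_states, dones, 0.99)
print(targets[2].item())  # the crash target is just its reward, -100.0
```

To use this in main.py, `Experience` would need an extra `done` field, concatenated in `extract_tensors` like the other fields.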
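The exploration schedule is worth checking alongside the other hyperparameters. With the values in main.py, `EpsilonGreedyStrategy` gives epsilon(n) = 0.01 + 0.99 * exp(-0.000001 * n) after n calls to `select_action`, which is one call per frame. The quick calculation below reuses the same formula. `exploration_rate` and `steps_until` are helpers introduced here, not part of the question's code.

```python
import math

EPS_START, EPS_END, EPS_DECAY = 1.0, 0.01, 0.000001  # values from main.py


def exploration_rate(step):
    # Same formula as EpsilonGreedyStrategy.get_exploration_rate
    return EPS_END + (EPS_START - EPS_END) * math.exp(-step * EPS_DECAY)


def steps_until(eps_target):
    # Solve EPS_END + (EPS_START - EPS_END) * exp(-EPS_DECAY * n) = eps_target for n
    return math.log((EPS_START - EPS_END) / (eps_target - EPS_END)) / EPS_DECAY


for n in (100_000, 1_000_000, 3_000_000):
    print(f"step {n:>9,}: epsilon = {exploration_rate(n):.3f}")

n05 = steps_until(0.05)
print(f"epsilon reaches 0.05 after about {n05:,.0f} steps "
      f"(at least {n05 / 60 / 3600:.1f} h at 60 frames per second)")
```

By this formula epsilon is still about 0.37 after one million frames and only drops to 0.05 after roughly 3.2 million. Because `clock.tick(60)` caps the game at 60 frames per second, that is close to 15 hours of play in which a large share of the actions are random.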

FlappyBird.py

    # Imports
    import pygame
    import sys
    import random
    import math

    # ----------------------------------------------------------------------------------------------------------------------
    # Game Class
    class FlappyBird:
        def __init__(self):
            self.pipe_is_spawned = False
            self.screen = pygame.display.set_mode((576, 1024))
            self.clock = pygame.time.Clock()
            self.game_font = pygame.font.Font('Flappy bird assets/04B_19.TTF', 60)

            # Games Variables
            self.gravity = 0.6
            self.bird_movement = 0
            self.game_active = False
            self.score = 0
            self.high_score = 0
            self.speed = 100

            self.bg_surface = pygame.image.load('Flappy bird assets/sprites/background-day.png').convert()
            self.bg_surface = pygame.transform.scale2x(self.bg_surface)

            self.floor_surface = pygame.image.load('Flappy bird assets/sprites/base.png').convert()
            self.floor_surface = pygame.transform.scale2x(self.floor_surface)
            self.floor_x_pos = 0

            self.bird_downflap = pygame.transform.scale2x(
                pygame.image.load('Flappy bird assets/sprites/bluebird-downflap.png')).convert_alpha()
            self.bird_midflap = pygame.transform.scale2x(
                pygame.image.load('Flappy bird assets/sprites/bluebird-midflap.png')).convert_alpha()
            self.bird_upflap = pygame.transform.scale2x(
                pygame.image.load('Flappy bird assets/sprites/bluebird-upflap.png')).convert_alpha()
            self.bird_frames = [self.bird_downflap, self.bird_midflap, self.bird_upflap]
            self.bird_index = 0
            self.bird_surface = self.bird_frames[self.bird_index]
            self.bird_rect = self.bird_surface.get_rect(center=(100, 512))

            self.BIRDFLAP = pygame.USEREVENT + 1
            pygame.time.set_timer(self.BIRDFLAP, 800)

            self.SPEEDUP = pygame.USEREVENT + 2
            pygame.time.set_timer(self.SPEEDUP, 40000)

            self.pipe_surface = pygame.image.load('Flappy bird assets/sprites/pipe-green.png').convert()
            self.pipe_surface = pygame.transform.scale2x(self.pipe_surface)
            self.pipe_list = []
            self.SPAWNPIPE = pygame.USEREVENT
            pygame.time.set_timer(self.SPAWNPIPE, 2400)
            self.pipe_height = [525, 550, 600, 650, 675]

            self.point_collider = pygame.Surface((10, 300))
            self.point_collider.set_alpha(0)
            self.collider_list = []

            self.game_over_surface = pygame.image.load('Flappy bird assets/sprites/message.png').convert_alpha()
            self.game_over_surface = pygame.transform.scale2x(self.game_over_surface)
            self.game_over_rect = self.game_over_surface.get_rect(center=(288, 512))

            self.flap_sound = pygame.mixer.Sound('Flappy bird assets/audio/wing.wav')
            self.death_sound = pygame.mixer.Sound('Flappy bird assets/audio/die.wav')
            self.hit_sound = pygame.mixer.Sound('Flappy bird assets/audio/hit.wav')
            self.point_sound = pygame.mixer.Sound('Flappy bird assets/audio/point.wav')

            self.rewardMultiplier = 1
            self.INCREASEREWARD = pygame.USEREVENT
            pygame.time.set_timer(self.INCREASEREWARD, 2400)

        @staticmethod
        def draw_floor(self):
            self.screen.blit(self.floor_surface, (self.floor_x_pos, 900))
            self.screen.blit(self.floor_surface, (self.floor_x_pos + 576, 900))

        @staticmethod
        def create_pipe(self):
            random_pipe_pos = random.choice(self.pipe_height)
            bottomPipe = self.pipe_surface.get_rect(midtop=(700, random_pipe_pos))
            topPipe = self.pipe_surface.get_rect(midbottom=(700, random_pipe_pos - 300))
            point_collider_rect = self.point_collider.get_rect(midbottom=(700, random_pipe_pos))
            return bottomPipe, topPipe, point_collider_rect

        @staticmethod
        def move_pipes(self, pipes, colliders):
            for pipe in pipes:
                pipe.centerx -= self.speed
            for collider in colliders:
                collider.centerx -= self.speed
            return pipes, colliders

        @staticmethod
        def draw_pipes(self, pipes, colliders):
            for pipe in pipes:
                if pipe.bottom >= 1024:
                    self.screen.blit(self.pipe_surface, pipe)
                else:
                    flip_pipe = pygame.transform.flip(self.pipe_surface, False, True)
                    self.screen.blit(flip_pipe, pipe)
            for collider in colliders:
                self.screen.blit(self.point_collider, collider)

        @staticmethod
        def check_collision(self, pipes):
            for pipe in pipes:
                if self.bird_rect.colliderect(pipe):
                    self.hit_sound.play()
                    return False

            if self.bird_rect.top <= -100 or self.bird_rect.bottom >= 900:
                self.death_sound.play()
                return False

            return True

        @staticmethod
        def check_pipe_reached(self, colliders):
            for collider in colliders:
                if self.bird_rect.colliderect(collider):
                    colliders.remove(collider)
                    self.pipe_is_spawned = False
                    return colliders, True
            return colliders, False

        @staticmethod
        def rotate_bird(self, bird):
            new_bird = pygame.transform.rotozoom(bird, -self.bird_movement * 2, 1)
            return new_bird

        @staticmethod
        def bird_animation(self):
            new_bird = self.bird_frames[self.bird_index]
            new_bird_rect = new_bird.get_rect(center=(100, self.bird_rect.centery))
            return new_bird, new_bird_rect

        @staticmethod
        def score_display(self):
            score_surface = self.game_font.render(str(self.score), True, (255, 255, 255))
            score_rect = score_surface.get_rect(center=(288, 100))
            self.screen.blit(score_surface, score_rect)

        @staticmethod
        def is_done(self):
            if not self.game_active:
                self.pipe_is_spawned = False
            return not self.game_active

        @staticmethod
        def get_score(self):
            return self.score

        @staticmethod
        def start_game(self):
            self.game_active = True
            self.pipe_list.clear()
            self.collider_list.clear()
            self.bird_rect.center = (100, 512)
            self.bird_movement = 0
            self.score = 0
            self.speed = 4
            self.pipe_is_spawned = False

        @staticmethod
        def check_event(self, action):
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
                    sys.exit()
                if action == 1:
                    self.bird_movement = 0
                    self.bird_movement -= 15
                    self.flap_sound.play()
                if event.type == self.SPAWNPIPE and self.game_active:
                    bottom_pipe, top_pipe, win_collider = self.create_pipe(self)
                    self.pipe_list.append(bottom_pipe)
                    self.pipe_list.append(top_pipe)
                    self.collider_list.append(win_collider)
                    self.pipe_is_spawned = True
                if event.type == self.BIRDFLAP:
                    if self.bird_index < 2:
                        self.bird_index += 1
                    else:
                        self.bird_index = 0
                    self.bird_surface, self.bird_rect = self.bird_animation(self)
                if event.type == self.SPEEDUP and self.game_active:
                    if self.speed < 8:
                        self.speed += 0.05
                if event.type == self.INCREASEREWARD and self.game_active:
                    self.rewardMultiplier += 0.01

        @staticmethod
        def is_game_active(self):
            # Bird
            self.bird_movement += self.gravity
            rotated_bird = self.rotate_bird(self, self.bird_surface)
            self.bird_rect.centery += self.bird_movement
            self.screen.blit(rotated_bird, self.bird_rect)
            self.game_active = self.check_collision(self, self.pipe_list)
            collider_list, pipe_reached = self.check_pipe_reached(self, self.collider_list)
            if pipe_reached:
                self.point_sound.play()
                self.score += 1

            # Pipes
            pipe_list, collider_list = self.move_pipes(self, self.pipe_list, self.collider_list)
            self.draw_pipes(self, pipe_list, collider_list)
            self.score_display(self)

        @staticmethod
        def draw_frame(self):
            self.screen.blit(self.bg_surface, (0, 0))
            if self.game_active:
                self.is_game_active(self)
            else:
                self.screen.blit(self.game_over_surface, self.game_over_rect)

            # Floor
            self.floor_x_pos -= 4
            self.draw_floor(self)
            if self.floor_x_pos <= -576:
                self.floor_x_pos = 0

            pygame.display.update()
            self.clock.tick(60)

        @staticmethod
        def get_reward(self):
            point = 0.1 * self.rewardMultiplier
            if not self.check_collision:
                point = -100 * self.rewardMultiplier
            elif self.check_pipe_reached:
                point = 50 * self.rewardMultiplier
            return point

        @staticmethod
        def get_state(self):
            if self.pipe_is_spawned:
                position = [self.bird_rect.centery,
                            self.collider_list[0].centerx,
                            self.collider_list[0].centery,
                            math.sqrt((self.collider_list[0].centerx - self.bird_rect.centerx) ** 2
                                      + (self.collider_list[0].centery - self.bird_rect.centery) ** 2),
                            self.collider_list[0].centery - self.bird_rect.centery,
                            self.collider_list[0].centerx - self.bird_rect.centerx,
                            self.bird_movement,
                            900 - self.bird_rect.bottom,
                            900 - self.collider_list[0].centery]
            else:
                position = [self.bird_rect.centery, 0, 0, 0, 0, 0, 900 - self.bird_rect.bottom, 50, self.bird_movement]
            return position

        @staticmethod
        def step(self, action):
            self.check_event(self, action)
            self.draw_frame(self)
            point = self.get_reward(self)
            self.is_done(self)
            position = self.get_state(self)
            return point, position
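In `get_reward`, `self.check_collision` and `self.check_pipe_reached` are referenced but never called. A function object is always truthy, so `not self.check_collision` is always `False` and the `elif` branch always runs. Every frame, including the one where the bird crashes, therefore returns `50 * self.rewardMultiplier`, and the agent never sees the -100 penalty. Simply calling the methods again from `get_reward` would have side effects: `check_collision` plays a sound, and `check_pipe_reached` removes the collider it hits.

One option is to record the outcome once per frame in `is_game_active` and derive the reward from those flags. Below is a minimal sketch. `compute_reward` and its `crashed` / `passed_pipe` arguments are hypothetical names, not part of the posted game. `buggy_reward` reproduces the original logic for comparison.

```python
def compute_reward(crashed, passed_pipe, multiplier=1.0):
    """Per-frame reward with the question's values: -100 on a crash,
    +50 for passing a pipe, +0.1 for surviving, all scaled by multiplier.
    crashed / passed_pipe would come from the check_collision and
    check_pipe_reached results already computed in is_game_active."""
    if crashed:
        return -100 * multiplier
    if passed_pipe:
        return 50 * multiplier
    return 0.1 * multiplier


def buggy_reward(check_collision, check_pipe_reached, multiplier=1.0):
    # Same logic as the posted get_reward: it tests the functions themselves.
    point = 0.1 * multiplier
    if not check_collision:        # a function object is truthy -> never taken
        point = -100 * multiplier
    elif check_pipe_reached:       # always taken
        point = 50 * multiplier
    return point


print(compute_reward(crashed=True, passed_pipe=False))   # -100.0
print(buggy_reward(lambda: False, lambda: False))        # 50.0, even on a crash
```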
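In `check_event`, the `if action == 1:` test sits inside `for event in pygame.event.get():`, so the agent's flap only takes effect on frames where at least one event is pending. On any frame with an empty queue, the action is silently ignored. That can be most frames, since the game's own timers fire only every 800 ms or more. The sketch below isolates that control flow. Both functions, `FLAP_VELOCITY`, and the plain `events` list are illustrative stand-ins, not pygame APIs.

```python
FLAP_VELOCITY = -15  # same result as "bird_movement = 0; bird_movement -= 15"


def flap_inside_event_loop(bird_movement, action, events):
    # Pattern from check_event: the action is only checked while iterating
    # over pending events, so an empty event queue drops it.
    for _event in events:
        if action == 1:
            bird_movement = FLAP_VELOCITY
    return bird_movement


def flap_once_per_frame(bird_movement, action, events):
    for _event in events:
        pass  # handle QUIT, SPAWNPIPE, BIRDFLAP, ... here as before
    # The agent chose an action for this frame; apply it exactly once.
    if action == 1:
        bird_movement = FLAP_VELOCITY
    return bird_movement


print(flap_inside_event_loop(3.0, action=1, events=[]))  # 3.0: flap ignored
print(flap_once_per_frame(3.0, action=1, events=[]))     # -15: flap applied
```

In the real class, this means moving the `if action == 1:` block out of the `for` loop in `check_event`.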

I'd say try adding another layer to your model architecture first, because it's very small. Also, the observations you're receiving from that environment contain some very large integers (raw pixel coordinates), so you might want to normalize each of those features. Another, unrelated suggestion: it is more efficient to implement the replay buffer as a queue, using collections.deque with maxlen set.
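The first two suggestions can be combined in a short sketch: divide each of the 9 state features by a scale so they land roughly in [-1, 1], and give the network a second hidden layer. The scale factors are assumptions derived from the 576x1024 window and the -15 flap velocity; they are not part of the original code. The hidden size of 64 is also an arbitrary choice.

Per-feature scaling only makes sense if `get_state` puts each feature in the same slot in both branches. In the posted version, the no-pipe branch places `900 - bird_rect.bottom` at index 6 and `bird_movement` at index 8, whereas the pipe branch has `bird_movement` at 6 and `900 - bird_rect.bottom` at 7.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

WIDTH, HEIGHT = 576, 1024  # window size used by FlappyBird

# Assumed per-feature scales, in the order of get_state's pipe branch:
# bird y, pipe x, pipe y, distance, dy, dx, bird velocity, bird-floor gap, pipe-floor gap
STATE_SCALE = torch.tensor([
    HEIGHT, WIDTH, HEIGHT, math.hypot(WIDTH, HEIGHT),
    HEIGHT, WIDTH, 15.0, HEIGHT, HEIGHT,
])


class DQN(nn.Module):
    """Two hidden layers instead of one; the raw state is rescaled first."""

    def __init__(self, in_features=9, hidden=64, num_actions=2):
        super().__init__()
        # A buffer moves with .to(device) and is saved in state_dict.
        self.register_buffer("scale", STATE_SCALE.clone())
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.out = nn.Linear(hidden, num_actions)

    def forward(self, t):
        t = t / self.scale  # pixel-sized values -> roughly [-1, 1]
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        return self.out(t)


net = DQN()
# Illustrative raw state: bird at (100, 512), next pipe collider at (700, 450).
raw_state = torch.tensor([[512.0, 700.0, 450.0, 603.2, -62.0, 600.0, -15.0, 364.0, 450.0]])
print(net(raw_state).shape)  # torch.Size([1, 2])
```

Dividing inside `forward` keeps the replay memory in raw game units. Scaling inside `get_state` instead would work equally well.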
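For the replay-buffer suggestion, here is a sketch of the same `ReplayMemory` interface (`push`, `sample`, `can_provide_sample`) built on `collections.deque` with `maxlen`. Once the deque is full, each `append` discards the oldest experience automatically, so the `push_count % capacity` bookkeeping goes away. One caveat from the deque documentation: indexed access is O(1) at both ends but O(n) in the middle, and `random.sample` indexes into the population. Sampling from a very large deque can therefore be slower than sampling from a list.

```python
import random
from collections import deque, namedtuple

Experience = namedtuple("Experience", ("state", "action", "next_state", "reward"))


class ReplayMemory:
    """Fixed-capacity FIFO buffer: the oldest experience is dropped automatically."""

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, experience):
        self.memory.append(experience)  # evicts the oldest item when full

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def can_provide_sample(self, batch_size):
        return len(self.memory) >= batch_size

    def __len__(self):
        return len(self.memory)


memory = ReplayMemory(capacity=3)
for i in range(5):
    memory.push(Experience(state=i, action=0, next_state=i + 1, reward=0.1))
print([e.state for e in memory.memory])  # [2, 3, 4]
```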

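Notice: the technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original address. For any questions, contact yoyou2525@163.com.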

 
Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM