
PyTorch example from Deep Reinforcement Learning in Action running too slow

I am running an example from the book Deep Reinforcement Learning in Action by Alexander Zai and Brandon Brown, but my code runs extremely slowly and I cannot for the life of me figure out why.

The only difference I can find between my example and theirs is that I am using a different game. However, slow move generation should not be the problem, because I can use the same game with a tabular method and easily train it on 500,000 samples in just a few minutes. Yet as soon as a neural network is involved, training on only 50,000 samples takes more than a few hours.
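To check whether the game itself or the network update is the bottleneck, a rough timing sketch along these lines could be used. It assumes the DayGathering, converter and qFunction objects shown further below; the timeComponents helper and the step count are placeholders, not part of the original code:

import time

# Rough timing sketch: compare the cost of a game transition with the cost
# of a network forward/backward pass. This ignores episode termination and
# is only meant as a crude comparison.
def timeComponents(nSteps=1000):
    game = DayGathering()
    state = game.observation()

    # time the environment transitions
    t0 = time.perf_counter()
    for _ in range(nSteps):
        game.playerTurn("PASS")
        game.boardTurn()
    envTime = time.perf_counter() - t0

    # time forward + backward passes on a single converted state
    x = converter(state)
    t0 = time.perf_counter()
    for _ in range(nSteps):
        out = qFunction(x)
        loss = out.sum()
        qFunction.zero_grad()
        loss.backward()
    netTime = time.perf_counter() - t0

    print(f"environment: {envTime:.3f}s, network: {netTime:.3f}s for {nSteps} steps each")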

I am using experience replay and a target network.

My code is as follows:

import copy
import random
from collections import deque

import numpy as np
import torch


def train(nGames, lGames, batchSize, replaySize, learning_rate, syncFreq):
    epsilon = .1 # the initial exploration rate
    gamma   = .9 # the discount of the rewards

    replay    = deque(maxlen=replaySize) # initialize the experience replay buffer

    loss_fn = torch.nn.MSELoss()
    # This qFunction is defined further into the question
    optimizer = torch.optim.Adam(qFunction.parameters(), lr=learning_rate)

    j        = 0
    losses   = []

    # Actions that we actually send to the game
    actionIndexed = {0: "U", 1: "D", 2: "L", 3: "R", 4: "PASS"}

    for e in range(nGames):

        progressBar(e, nGames) # Just a progress bar function. Not relevant

        game  = DayGathering() # This is my own game
        state = game.observation() # This just extracts a dictionary with relevant information about the game

        # start playing the game
        i = 0
        done = False
        while not done:
            # agent taking an action with epsilon-greedy strategy
            # (convert the state once per step so it is also available for the
            # replay buffer when a random action is taken)
            stateC = converter(state)
            if np.random.random() < epsilon:
                actionIndex = np.random.choice(range(5))
            else:
                # Select the action with the largest Q-value
                qValues     = qFunction(stateC).data.numpy()
                actionIndex = np.argmax(qValues)
            action = actionIndexed[actionIndex]

            # advancing the state of the board
            # This functions just move to the next state
            game.playerTurn(action)
            game.boardTurn()
            newState = game.observation()

            # checking for rewards
            reward = rewardFunction(newState)

            # check if the current game is terminal. This happens if i = lGame
            # - 1 or if the agent is caught outside during night fall
            if i == lGames - 1 or outsideAtNight(newState):
                done = True
            else:
                done = False

            # saving the (state, action, newstate, reward) values in the experience replay buffer
            replay.append((stateC, actionIndex, newState, reward, done))

            # when we get to the batch size we train the network
            if len(replay) > batchSize:
                miniBatch     = random.sample(replay, batchSize)
                stateBatch    = torch.cat([sC for sC, _, _, _, _ in miniBatch])
                actionBatch   = torch.Tensor([actionIndex for _, actionIndex, _, _, _ in miniBatch])
                rewardBatch   = torch.Tensor([reward for _, _, _, reward, _ in miniBatch])
                newStateBatch = torch.cat([converter(newState) for _, _, newState, _, _ in miniBatch])
                doneBatch     = torch.Tensor([done for _, _, _, _, done in miniBatch])

                # use the target network to bootstrap
                with torch.no_grad():
                    newStateQ = qFunctionTarget(newStateBatch)

                # compute the TD target: R + gamma * (1 - done) * max_a Q_target(s', a)
                Y = rewardBatch + gamma * (1 - doneBatch) * torch.max(newStateQ, dim=1)[0]

                # Q-values of the actions actually taken, from the online network
                stateQ = qFunction(stateBatch).gather(dim=1, index=actionBatch.long().unsqueeze(dim=1)).squeeze()

                # compute the loss of the model and backpropagate
                loss = loss_fn(stateQ, Y.detach())
                optimizer.zero_grad()
                loss.backward()
                # keeping track of the losses
                losses.append(loss.item())
                optimizer.step()

                # periodically copy parameters to the target network
                if j % syncFreq == 0:
                    qFunctionTarget.load_state_dict(qFunction.state_dict())

            # advance the state of the game
            state = newState

            i += 1
            j += 1 # advance j
    return losses
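
A call to this function would look roughly like the following; the hyperparameter values here are placeholders, not necessarily the ones actually used:

# Hypothetical invocation of the training loop above; all hyperparameter
# values are placeholders.
losses = train(nGames=1000, lGames=100, batchSize=200,
               replaySize=3000, learning_rate=1e-3, syncFreq=500)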

That is the main training loop. Now for the other things that appear in the code above.

The neural network:

qFunction = torch.nn.Sequential(torch.nn.Linear(426, 1024),
                                torch.nn.ReLU(),
                                torch.nn.Linear(1024, 512),
                                torch.nn.ReLU(),
                                torch.nn.Linear(512, 5),
                                )
# target network: a copy of qFunction, re-synced every syncFreq steps in train()
qFunctionTarget = copy.deepcopy(qFunction)
qFunctionTarget.load_state_dict(qFunction.state_dict())

The converter from an observation to the network input: I do not think this is very relevant, but I decided to include it anyway. Basically, the game is a grid world, and we simply convert the observation into a set of 0s and 1s representing positions on the board. After building this board we perturb it with noise.

def converter(observation):
    px, py = observation["player position"]
    playerBitBoard = torch.zeros([9, 9])
    playerBitBoard[py, px] = 1

    # making the bit boards for all apples
    a1x, a1y = observation["apple 1"]
    apple1BitBoard = torch.zeros([9, 9])
    apple1BitBoard[a1y, a1x] = 1

    a2x, a2y = observation["apple 2"]
    apple2BitBoard = torch.zeros([9, 9])
    apple2BitBoard[a2y, a2x] = 1

    a3x, a3y = observation["apple 3"]
    apple3BitBoard = torch.zeros([9, 9])
    apple3BitBoard[a3y, a3x] = 1

    # making the bit board for the key. Note that the key position can be
    # None; in that case the key bit board is simply left as all zeros
    key = observation["key position"]
    keyBitBoard = torch.zeros([9, 9])
    if key:
        kx, ky = key
        keyBitBoard[ky, kx] = 1

    # making the time features
    t = observation["time until nigth"]
    timeBitBoard = torch.zeros(21)
    timeBitBoard[t] = 1

    # now we join all in a single array
    return torch.cat([playerBitBoard.reshape(81),
                      apple1BitBoard.reshape(81),
                      apple2BitBoard.reshape(81),
                      apple3BitBoard.reshape(81),
                      keyBitBoard.reshape(81),
                      timeBitBoard]).reshape((1, 426)) + torch.rand((1, 426))/100

That converter looks quite slow; maybe you could vectorize it.
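
For example, filling a single pre-allocated feature vector avoids building and reshaping five separate boards on every call. This is only a sketch, assuming the same observation keys and feature layout as the original converter (the name converter_vectorized is made up here):

import torch

def converter_vectorized(observation):
    # One flat feature vector: player [0:81], apples [81:324], key [324:405],
    # time one-hot [405:426]; position (x, y) maps to offset y * 9 + x.
    features = torch.zeros(426)

    px, py = observation["player position"]
    features[py * 9 + px] = 1

    for i, name in enumerate(["apple 1", "apple 2", "apple 3"]):
        ax, ay = observation[name]
        features[81 * (i + 1) + ay * 9 + ax] = 1

    key = observation["key position"]
    if key:
        kx, ky = key
        features[81 * 4 + ky * 9 + kx] = 1

    # note: the observation key is spelled "time until nigth" in the game
    t = observation["time until nigth"]
    features[81 * 5 + t] = 1

    # same noise as the original converter
    return features.unsqueeze(0) + torch.rand((1, 426)) / 100

Note also that in the training loop above, converter(newState) is called for every element of every mini-batch; storing the already converted new-state tensor in the replay buffer instead of the raw observation would avoid repeating that work on each update.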
