
PyTorch example from Deep Reinforcement Learning in Action running too slowly

I am running an example from the book Deep Reinforcement Learning in Action by Alexander Zai and Brandon Brown, but my code runs extremely slowly and for the life of me I cannot understand why.

The only difference I can find between my example and the book's is that I am using a different game; however, slow move generation should not be the problem, as I can use the same game with tabular methods and easily train on 500,000 samples in just a few minutes. But as soon as neural networks are involved, training on only 50,000 samples takes more than a couple of hours.
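For reference, one simple way to check where the time actually goes is to accumulate wall-clock timings around the main suspects (state conversion versus the forward/backward passes) in the training loop shown further down. A minimal sketch, where the timed helper and the labels are hypothetical and not part of the book's code:

import time
from contextlib import contextmanager

@contextmanager
def timed(label, totals):
    # accumulate wall-clock seconds per label so the cost of each part of the loop can be compared
    start = time.perf_counter()
    yield
    totals[label] = totals.get(label, 0.0) + time.perf_counter() - start

# hypothetical usage inside the training loop below:
#   totals = {}
#   with timed("convert", totals):
#       stateC = converter(state)
#   with timed("train step", totals):
#       loss.backward()
#       optimizer.step()
#   print(totals)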

I am using both experience replay and a target network.

My code is as follows:

import copy
import random
from collections import deque

import numpy as np
import torch


def train(nGames, lGames, batchSize, replaySize, learning_rate, syncFreq):
    epsilon = .1 # the initial exploration rate
    gamma   = .9 # the discount of the rewards

    replay    = deque(maxlen=replaySize) # initialize the experience replay buffer

    loss_fn = torch.nn.MSELoss()
    # This qFunction is defined further into the question
    optimizer = torch.optim.Adam(qFunction.parameters(), lr=learning_rate)

    j        = 0
    losses   = []

    # Actions that we actually send to the game
    actionIndexed = {0: "U", 1: "D", 2: "L", 3: "R", 4: "PASS"}

    for e in range(nGames):

        progressBar(e, nGames) # Just a progress bar function. Not relevant

        game  = DayGathering() # This is my own game
        state = game.observation() # This just extracts a dictionary with relevant information about the game

        # start playing the game
        i = 0
        done = False
        while not done:
            # agent taking an action with an epsilon-greedy strategy
            # (convert the state up front so the replay buffer below always
            # stores the encoding of the current state, even for random actions)
            stateC = converter(state)
            if np.random.random() < epsilon:
                actionIndex = np.random.choice(range(5))
                action      = actionIndexed[actionIndex]
            else:
                # select the action with the largest Q-value
                qValues     = qFunction(stateC).data.numpy()
                actionIndex = np.argmax(qValues)
                action      = actionIndexed[actionIndex]

            # advancing the state of the board
            # This functions just move to the next state
            game.playerTurn(action)
            game.boardTurn()
            newState = game.observation()

            # checking for rewards
            reward = rewardFunction(newState)

            # check if the current game is terminal. This happens if i == lGames - 1
            # or if the agent is caught outside during nightfall
            done = (i == lGames - 1) or outsideAtNight(newState)

            # saving the (state, action, newstate, reward) values in the experience replay buffer
            replay.append((stateC, actionIndex, newState, reward, done))

            # when we get to the batch size we train the network
            if len(replay) > batchSize:
                miniBatch     = random.sample(replay, batchSize)
                stateBatch    = torch.cat([sC for sC, _, _, _, _ in miniBatch])
                actionBatch   = torch.Tensor([actionIndex for _, actionIndex, _, _, _ in miniBatch])
                rewardBatch   = torch.Tensor([reward for _, _, _, reward, _ in miniBatch])
                newStateBatch = torch.cat([converter(newState) for _, _, newState, _, _ in miniBatch])
                doneBatch     = torch.Tensor([done for _, _, _, _, done in miniBatch])

                # use the target network to bootstrap
                with torch.no_grad():
                    newStateQ = qFunctionTarget(newStateBatch)

                # compute the TD target R + gamma * max_a Q_target(s', a),
                # with the bootstrap term zeroed out for terminal states
                Y = rewardBatch + gamma * (1 - doneBatch) * torch.max(newStateQ, dim=1)[0]

                # compute the online network's Q-values for the actions actually taken
                stateQ = qFunction(stateBatch).gather(dim=1, index=actionBatch.long().unsqueeze(dim=1)).squeeze()

                # compute the loss of the model and backpropagate
                loss = loss_fn(stateQ, Y.detach())
                optimizer.zero_grad()
                loss.backward()
                # keeping track of the losses
                losses.append(loss.item())
                optimizer.step()

                # periodically copy parameters to the target network
                if j % syncFreq == 0:
                    qFunctionTarget.load_state_dict(qFunction.state_dict())

            # advance the state of the game
            state = newState

            i += 1
            j += 1 # advance j
    return losses

That was the main training loop. Now for the other pieces that appear in the code above.

The neural network

qFunction = torch.nn.Sequential(torch.nn.Linear(426, 1024),
                                torch.nn.ReLU(),
                                torch.nn.Linear(1024, 512),
                                torch.nn.ReLU(),
                                torch.nn.Linear(512, 5),
                                )
qFunctionTarget = copy.deepcopy(qFunction)
qFunctionTarget.load_state_dict(qFunction.state_dict())

The converter from observations to the network input: I don't think this is very relevant, but I decided to include it anyway. Basically, the game is a grid world, and we simply convert the observation into an array of zeros and ones indicating positions on the board. After creating this board we perturb it with noise.

def converter(observation):
    px, py = observation["player position"]
    playerBitBoard = torch.zeros([9, 9])
    playerBitBoard[py, px] = 1

    # making the bit boards for all apples
    a1x, a1y = observation["apple 1"]
    apple1BitBoard = torch.zeros([9, 9])
    apple1BitBoard[a1y, a1x] = 1

    a2x, a2y = observation["apple 2"]
    apple2BitBoard = torch.zeros([9, 9])
    apple2BitBoard[a2y, a2x] = 1

    a3x, a3y = observation["apple 3"]
    apple3BitBoard = torch.zeros([9, 9])
    apple3BitBoard[a3y, a3x] = 1

    # making the bit board for the key. Note that the key position may be
    # None; in that case the key bit board is simply left as all zeros
    key = observation["key position"]
    keyBitBoard = torch.zeros([9, 9])
    if key:
        kx, ky = key
        keyBitBoard[ky, kx] = 1

    # making the time features
    t = observation["time until nigth"]
    timeBitBoard = torch.zeros(21)
    timeBitBoard[t] = 1

    # now we join all in a single array
    return torch.cat([playerBitBoard.reshape(81),
                      apple1BitBoard.reshape(81),
                      apple2BitBoard.reshape(81),
                      apple3BitBoard.reshape(81),
                      keyBitBoard.reshape(81),
                      timeBitBoard]).reshape((1, 426)) + torch.rand((1, 426))/100

That converter function looks slow; maybe you can vectorize it.
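For reference, a minimal sketch of what that could look like, assuming the observation layout shown above; the name converterVectorized and the flat index layout below are assumptions, not code from the book:

import torch

def converterVectorized(observation):
    # write all five one-hot boards and the time one-hot directly into a single
    # preallocated 426-dimensional vector instead of building and concatenating
    # six intermediate tensors on every call
    features = torch.zeros(426)
    positions = [observation["player position"],
                 observation["apple 1"],
                 observation["apple 2"],
                 observation["apple 3"],
                 observation["key position"]]   # the key position may be None
    for board, pos in enumerate(positions):
        if pos is not None:
            x, y = pos
            features[board * 81 + y * 9 + x] = 1.0   # same row-major layout as bitBoard[y, x].reshape(81)
    # the time one-hot occupies the last 21 entries (key name matches the question's code)
    features[5 * 81 + observation["time until nigth"]] = 1.0
    return features.unsqueeze(0) + torch.rand((1, 426)) / 100

This should produce the same (1, 426) tensor (plus noise) as the original converter, just without the intermediate allocations.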
