PyTorch example from Deep Reinforcement Learning in Action running too slow
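I'm running an example from the book Deep Reinforcement Learning in Action by Alexander Zai and Brandon Brown, but my code runs extremely slowly and I can't for the life of me figure out why.
The only difference I can spot between my example and theirs is that I'm using a different game. However, slow game generation shouldn't be the problem, because I can use a tabular method with the same game and easily train it on 500000 samples in just a few minutes. But once a neural network is involved, training on only 50000 takes more than a few hours.
I'm using experience replay and a target network.
My code is as follows: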
import copy
import random
from collections import deque

import numpy as np
import torch


def train(nGames, lGames, batchSize, replaySize, learning_rate, syncFreq):
    epsilon = .1  # the initial exploration rate
    gamma = .9  # the discount of the rewards
    replay = deque(maxlen=replaySize)  # initialize the experience replay buffer
    loss_fn = torch.nn.MSELoss()
    # This qFunction is defined further into the question
    optimizer = torch.optim.Adam(qFunction.parameters(), lr=learning_rate)
    j = 0
    losses = []
    # Actions that we actually send to the game
    actionIndexed = {0: "U", 1: "D", 2: "L", 3: "R", 4: "PASS"}
    for e in range(nGames):
        progressBar(e, nGames)  # Just a progress bar function. Not relevant
        game = DayGathering()  # This is my own game
        state = game.observation()  # This just extracts a dictionary with relevant information about the game
        # start playing the game
        i = 0
        done = False
        while not done:
            # convert the state up front so stateC also exists when the action is random
            stateC = converter(state)
            # agent taking an action with epsilon-greedy strategy
            if np.random.random() < epsilon:
                actionIndex = np.random.choice(range(5))
            else:
                # Select the action with the largest Q-value
                qValues = qFunction(stateC).detach().numpy()
                actionIndex = np.argmax(qValues)
            action = actionIndexed[actionIndex]
            # advancing the state of the board
            # These functions just move to the next state
            game.playerTurn(action)
            game.boardTurn()
            newState = game.observation()
            # checking for rewards
            reward = rewardFunction(newState)
            # check if the current game is terminal. This happens if i = lGames - 1
            # or if the agent is caught outside during nightfall
            done = i == lGames - 1 or outsideAtNight(newState)
            # saving the (state, action, newState, reward, done) values in the experience replay buffer
            replay.append((stateC, actionIndex, newState, reward, done))
            # once the buffer holds more than batchSize samples we train the network
            if len(replay) > batchSize:
                miniBatch = random.sample(replay, batchSize)
                stateBatch = torch.cat([sC for sC, _, _, _, _ in miniBatch])
                actionBatch = torch.Tensor([actionIndex for _, actionIndex, _, _, _ in miniBatch])
                rewardBatch = torch.Tensor([reward for _, _, _, reward, _ in miniBatch])
                # note: converter runs once per sampled transition, on every step
                newStateBatch = torch.cat([converter(newState) for _, _, newState, _, _ in miniBatch])
                doneBatch = torch.Tensor([done for _, _, _, _, done in miniBatch])
                # use the target network to bootstrap
                with torch.no_grad():
                    newStateQ = qFunctionTarget(newStateBatch)
                # compute the target of the network: R + gamma * max Q(s', a')
                Y = rewardBatch + gamma * (1 - doneBatch) * torch.max(newStateQ, dim=1)[0]
                # compute the Q-values of the taken actions using the original network
                stateQ = qFunction(stateBatch).gather(dim=1, index=actionBatch.long().unsqueeze(dim=1)).squeeze()
                # compute the loss of the model and backpropagate
                loss = loss_fn(stateQ, Y.detach())
                optimizer.zero_grad()
                loss.backward()
                # keeping track of the losses
                losses.append(loss.item())
                optimizer.step()
                # periodically copy parameters to the target network
                if j % syncFreq == 0:
                    qFunctionTarget.load_state_dict(qFunction.state_dict())
            # advance the state of the game
            state = newState
            i += 1
            j += 1  # advance j
    return losses
That is the main training loop. Now for the other things that appear in the code above.
The neural network:
qFunction = torch.nn.Sequential(
    torch.nn.Linear(426, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 5),
)
qFunctionTarget = copy.deepcopy(qFunction)
qFunctionTarget.load_state_dict(qFunction.state_dict())
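The converter from the observation to the ANN input: I don't think this is very relevant, but I decided to include it anyway. Basically, the game is a grid world, and we simply convert the observation into a set of 0s and 1s representing the positions on the board. After creating this board, we perturb it with noise.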
def converter(observation):
    px, py = observation["player position"]
    playerBitBoard = torch.zeros([9, 9])
    playerBitBoard[py, px] = 1
    # making the bit boards for all apples
    a1x, a1y = observation["apple 1"]
    apple1BitBoard = torch.zeros([9, 9])
    apple1BitBoard[a1y, a1x] = 1
    a2x, a2y = observation["apple 2"]
    apple2BitBoard = torch.zeros([9, 9])
    apple2BitBoard[a2y, a2x] = 1
    a3x, a3y = observation["apple 3"]
    apple3BitBoard = torch.zeros([9, 9])
    apple3BitBoard[a3y, a3x] = 1
    # making the bit board for the key. Note that the key can have a None value
    # as its position. This must be dealt with. The current solution leaves all
    # values in the key board equal to zero
    key = observation["key position"]
    keyBitBoard = torch.zeros([9, 9])
    if key:
        kx, ky = key
        keyBitBoard[ky, kx] = 1
    # making the time features
    t = observation["time until nigth"]
    timeBitBoard = torch.zeros(21)
    timeBitBoard[t] = 1
    # now we join all in a single array
    return torch.cat([playerBitBoard.reshape(81),
                      apple1BitBoard.reshape(81),
                      apple2BitBoard.reshape(81),
                      apple3BitBoard.reshape(81),
                      keyBitBoard.reshape(81),
                      timeBitBoard]).reshape((1, 426)) + torch.rand((1, 426))/100
That converter looks pretty slow; maybe you can vectorize it.
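One way the suggestion above might look is sketched here. It is only an illustration under assumptions, not a tested fix for the asker's setup: the name `converter_fast`, the `noise` flag and the module-level constants are hypothetical, and the layout (five 9x9 boards followed by a 21-slot time vector) is read off the converter in the question. Instead of allocating six tensors and concatenating them, it computes the flat indices of the ones and writes them into a single preallocated tensor.

```python
import torch

BOARD = 9                        # grid side length, from the 9x9 bit boards in the question
CELLS = BOARD * BOARD            # 81 cells per bit board
N_TIME = 21                      # length of the one-hot time feature
N_FEATURES = 5 * CELLS + N_TIME  # 426, the input size of the network

# order of the five bit boards inside the flat feature vector
POSITION_KEYS = ["player position", "apple 1", "apple 2", "apple 3", "key position"]


def converter_fast(observation, noise=True):
    """Build the (1, 426) input with one allocation instead of six (hypothetical helper)."""
    indices = []
    for slot, name in enumerate(POSITION_KEYS):
        pos = observation[name]
        if pos:  # a missing position (None, as for the key) leaves that board all zeros
            x, y = pos
            # same cell as board[y, x] followed by reshape(81) in the original
            indices.append(slot * CELLS + y * BOARD + x)
    # one-hot time feature sits after the five boards
    indices.append(5 * CELLS + observation["time until nigth"])

    features = torch.zeros(1, N_FEATURES)
    features[0, indices] = 1.0
    if noise:
        # same small uniform noise as in the original converter
        features += torch.rand(1, N_FEATURES) / 100
    return features
```

Independently of how the converter is written, the training loop in the question calls it once per sampled transition on every step. Storing the already converted next state in the replay buffer, so that each observation is converted only once, would probably save more time than the vectorization itself, though that is worth confirming with a profiler.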