[英]Slow performance of PyTorch Categorical
我一直在使用 PPO(近端策略优化)架构在自定义模拟器中训练我的代理。 我的模拟器已经变得非常快,因为它写在 Rust 中。因此,我的内部循环的速度受到 PPO 代理内部的某些功能的瓶颈。
当我使用 pyinstrument 分析 function 时,它表明大部分时间都花在初始化分类 class 和计算对数概率上。
我希望有人可以提供帮助,如果有更快的方法可以使用 PyTorch 来完成此操作。
def act(self, state):
action_probs = self.actor(state)
dist = Categorical(action_probs)
action = dist.sample()
action_logprob = dist.log_prob(action)
return action.detach(), action_logprob.detach()
def evaluate(self, state, action):
"""Evaluates the action given the state."""
action_probs = self.actor(state)
dist = Categorical(action_probs)
action_logprobs = dist.log_prob(action)
dist_entropy = dist.entropy()
state_values = self.critic(state)
return action_logprobs, state_values, dist_entropy
我已经看到一些其他技术可以做到这一点,但我不太清楚它们是否会提高速度。
我前一段时间遇到了同样的问题,并通过从pytorch 源代码复制来实现我的自定义Categorical
class
它类似于原始代码,但删除了不必要的功能。 不需要每次都初始化 class ,而是初始化一次,只需使用set_probs()
或set_probs_()
来设置新的概率值。 此外,它仅适用于概率值作为输入(不是 logits),但我们无论如何都可以在 logits 上手动应用softmax
。
import torch
from torch.distributions.utils import probs_to_logits
class Categorical:
def __init__(self, probs_shape):
# NOTE: probs_shape is supposed to be
# the shape of probs that will be
# produced by policy network
if len(probs_shape) < 1:
raise ValueError("`probs_shape` must be at least 1.")
self.probs_dim = len(probs_shape)
self.probs_shape = probs_shape
self._num_events = probs_shape[-1]
self._batch_shape = probs_shape[:-1] if self.probs_dim > 1 else torch.Size()
self._event_shape=torch.Size()
def set_probs_(self, probs):
self.probs = probs
self.logits = probs_to_logits(self.probs)
def set_probs(self, probs):
self.probs = probs / probs.sum(-1, keepdim=True)
self.logits = probs_to_logits(self.probs)
def sample(self, sample_shape=torch.Size()):
if not isinstance(sample_shape, torch.Size):
sample_shape = torch.Size(sample_shape)
probs_2d = self.probs.reshape(-1, self._num_events)
samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T
return samples_2d.reshape(sample_shape + self._batch_shape + self._event_shape)
def log_prob(self, value):
value = value.long().unsqueeze(-1)
value, log_pmf = torch.broadcast_tensors(value, self.logits)
value = value[..., :1]
return log_pmf.gather(-1, value).squeeze(-1)
def entropy(self):
min_real = torch.finfo(self.logits.dtype).min
logits = torch.clamp(self.logits, min=min_real)
p_log_p = logits * self.probs
return -p_log_p.sum(-1)
检查执行时间:
import time
import torch as tt
import torch.distributions as td
首先检查内置torch.distributions.Categorical
start=time.perf_counter()
for _ in range(50000):
probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
ct = td.Categorical(probs=probs)
entropy = ct.entropy()
action = ct.sample()
log_prob = ct.log_prob(action)
entropy, action, log_prob
end=time.perf_counter()
print(end - start)
output:
"""
10.024958199996036
"""
现在检查自定义Categorical
start=time.perf_counter()
ct = Categorical((3,4,2)) #<--- initialize class beforehand
for _ in range(50000):
probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
ct.set_probs(probs)
entropy = ct.entropy()
action = ct.sample()
log_prob = ct.log_prob(action)
entropy, action, log_prob
end=time.perf_counter()
print(end - start)
output:
"""
4.565093299999717
"""
执行时间减少了一半多一点。 如果我们使用set_probs_()
而不是set_probs()
,它可以进一步减少。 set_probs()
和set_probs_()
之间存在细微差别,后者跳过了probs / probs.sum(-1, keepdim=True)
行,该行应该消除浮点错误。 但是,它可能并不总是必要的。
start=time.perf_counter()
ct = Categorical((3,4,2)) #<--- initialize class beforehand
for _ in range(50000):
probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
ct.set_probs_(probs)
entropy = ct.entropy()
action = ct.sample()
log_prob = ct.log_prob(action)
entropy, action, log_prob
end=time.perf_counter()
print(end - start)
output:
"""
3.9343119999975897
"""
您可以在您机器上的某个位置检查 pytorch 分发模块的源代码..\Lib\site-packages\torch\distributions
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.