[英]Slow performance of PyTorch Categorical
我一直在使用 PPO(近端策略優化)架構在自定義模擬器中訓練我的代理。 我的模擬器已經變得非常快,因為它寫在 Rust 中。因此,我的內部循環的速度受到 PPO 代理內部的某些功能的瓶頸。
當我使用 pyinstrument 分析 function 時,它表明大部分時間都花在初始化分類 class 和計算對數概率上。
我希望有人可以提供幫助,如果有更快的方法可以使用 PyTorch 來完成此操作。
def act(self, state):
action_probs = self.actor(state)
dist = Categorical(action_probs)
action = dist.sample()
action_logprob = dist.log_prob(action)
return action.detach(), action_logprob.detach()
def evaluate(self, state, action):
"""Evaluates the action given the state."""
action_probs = self.actor(state)
dist = Categorical(action_probs)
action_logprobs = dist.log_prob(action)
dist_entropy = dist.entropy()
state_values = self.critic(state)
return action_logprobs, state_values, dist_entropy
我已經看到一些其他技術可以做到這一點,但我不太清楚它們是否會提高速度。
我前一段時間遇到了同樣的問題,並通過從pytorch 源代碼復制來實現我的自定義Categorical
class
它類似於原始代碼,但刪除了不必要的功能。 不需要每次都初始化 class ,而是初始化一次,只需使用set_probs()
或set_probs_()
來設置新的概率值。 此外,它僅適用於概率值作為輸入(不是 logits),但我們無論如何都可以在 logits 上手動應用softmax
。
import torch
from torch.distributions.utils import probs_to_logits
class Categorical:
def __init__(self, probs_shape):
# NOTE: probs_shape is supposed to be
# the shape of probs that will be
# produced by policy network
if len(probs_shape) < 1:
raise ValueError("`probs_shape` must be at least 1.")
self.probs_dim = len(probs_shape)
self.probs_shape = probs_shape
self._num_events = probs_shape[-1]
self._batch_shape = probs_shape[:-1] if self.probs_dim > 1 else torch.Size()
self._event_shape=torch.Size()
def set_probs_(self, probs):
self.probs = probs
self.logits = probs_to_logits(self.probs)
def set_probs(self, probs):
self.probs = probs / probs.sum(-1, keepdim=True)
self.logits = probs_to_logits(self.probs)
def sample(self, sample_shape=torch.Size()):
if not isinstance(sample_shape, torch.Size):
sample_shape = torch.Size(sample_shape)
probs_2d = self.probs.reshape(-1, self._num_events)
samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T
return samples_2d.reshape(sample_shape + self._batch_shape + self._event_shape)
def log_prob(self, value):
value = value.long().unsqueeze(-1)
value, log_pmf = torch.broadcast_tensors(value, self.logits)
value = value[..., :1]
return log_pmf.gather(-1, value).squeeze(-1)
def entropy(self):
min_real = torch.finfo(self.logits.dtype).min
logits = torch.clamp(self.logits, min=min_real)
p_log_p = logits * self.probs
return -p_log_p.sum(-1)
檢查執行時間:
import time
import torch as tt
import torch.distributions as td
首先檢查內置torch.distributions.Categorical
start=time.perf_counter()
for _ in range(50000):
probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
ct = td.Categorical(probs=probs)
entropy = ct.entropy()
action = ct.sample()
log_prob = ct.log_prob(action)
entropy, action, log_prob
end=time.perf_counter()
print(end - start)
output:
"""
10.024958199996036
"""
現在檢查自定義Categorical
start=time.perf_counter()
ct = Categorical((3,4,2)) #<--- initialize class beforehand
for _ in range(50000):
probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
ct.set_probs(probs)
entropy = ct.entropy()
action = ct.sample()
log_prob = ct.log_prob(action)
entropy, action, log_prob
end=time.perf_counter()
print(end - start)
output:
"""
4.565093299999717
"""
執行時間減少了一半多一點。 如果我們使用set_probs_()
而不是set_probs()
,它可以進一步減少。 set_probs()
和set_probs_()
之間存在細微差別,后者跳過了probs / probs.sum(-1, keepdim=True)
行,該行應該消除浮點錯誤。 但是,它可能並不總是必要的。
start=time.perf_counter()
ct = Categorical((3,4,2)) #<--- initialize class beforehand
for _ in range(50000):
probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
ct.set_probs_(probs)
entropy = ct.entropy()
action = ct.sample()
log_prob = ct.log_prob(action)
entropy, action, log_prob
end=time.perf_counter()
print(end - start)
output:
"""
3.9343119999975897
"""
您可以在您機器上的某個位置檢查 pytorch 分發模塊的源代碼..\Lib\site-packages\torch\distributions
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.