繁体   English   中英

PyTorch 分类性能缓慢

[英]Slow performance of PyTorch Categorical

我一直在使用 PPO(近端策略优化)架构在自定义模拟器中训练我的代理。 我的模拟器已经变得非常快,因为它写在 Rust 中。因此,我的内部循环的速度受到 PPO 代理内部的某些功能的瓶颈。

当我使用 pyinstrument 分析 function 时,它表明大部分时间都花在初始化分类 class 和计算对数概率上。

我希望有人可以提供帮助,如果有更快的方法可以使用 PyTorch 来完成此操作。

    def act(self, state):
        action_probs = self.actor(state)
        dist = Categorical(action_probs)

        action = dist.sample()
        action_logprob = dist.log_prob(action)

        return action.detach(), action_logprob.detach()

    def evaluate(self, state, action):
        """Evaluates the action given the state."""
        action_probs = self.actor(state)
        dist = Categorical(action_probs)

        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_values = self.critic(state)

        return action_logprobs, state_values, dist_entropy

Pyinstrument 显示程序的速度。

我已经看到一些其他技术可以做到这一点,但我不太清楚它们是否会提高速度。

我前一段时间遇到了同样的问题,并通过从pytorch 源代码复制来实现我的自定义Categorical class

它类似于原始代码,但删除了不必要的功能。 不需要每次都初始化 class ,而是初始化一次,只需使用set_probs()set_probs_()来设置新的概率值。 此外,它仅适用于概率值作为输入(不是 logits),但我们无论如何都可以在 logits 上手动应用softmax

import torch
from torch.distributions.utils import probs_to_logits
class Categorical:
    def __init__(self, probs_shape): 
        # NOTE: probs_shape is supposed to be 
        #       the shape of probs that will be 
        #       produced by policy network
        if len(probs_shape) < 1: 
            raise ValueError("`probs_shape` must be at least 1.")
        self.probs_dim = len(probs_shape) 
        self.probs_shape = probs_shape
        self._num_events = probs_shape[-1]
        self._batch_shape = probs_shape[:-1] if self.probs_dim > 1 else torch.Size()
        self._event_shape=torch.Size()

    def set_probs_(self, probs):
        self.probs = probs
        self.logits = probs_to_logits(self.probs)

    def set_probs(self, probs):
        self.probs = probs / probs.sum(-1, keepdim=True) 
        self.logits = probs_to_logits(self.probs)

    def sample(self, sample_shape=torch.Size()):
        if not isinstance(sample_shape, torch.Size):
            sample_shape = torch.Size(sample_shape)
        probs_2d = self.probs.reshape(-1, self._num_events)
        samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T
        return samples_2d.reshape(sample_shape + self._batch_shape + self._event_shape)

    def log_prob(self, value):
        value = value.long().unsqueeze(-1)
        value, log_pmf = torch.broadcast_tensors(value, self.logits)
        value = value[..., :1]
        return log_pmf.gather(-1, value).squeeze(-1)

    def entropy(self):
        min_real = torch.finfo(self.logits.dtype).min
        logits = torch.clamp(self.logits, min=min_real)
        p_log_p = logits * self.probs
        return -p_log_p.sum(-1)


检查执行时间:

import time
import torch as tt
import torch.distributions as td

首先检查内置torch.distributions.Categorical

start=time.perf_counter()
for _ in range(50000):
    probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
    ct = td.Categorical(probs=probs)
    entropy = ct.entropy()
    action = ct.sample()
    log_prob = ct.log_prob(action)
    entropy, action, log_prob
end=time.perf_counter()
print(end - start)

output:

"""
10.024958199996036
"""

现在检查自定义Categorical

start=time.perf_counter()
ct = Categorical((3,4,2)) #<--- initialize class beforehand
for _ in range(50000):
    probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
    ct.set_probs(probs)
    entropy = ct.entropy()
    action = ct.sample()
    log_prob = ct.log_prob(action)
    entropy, action, log_prob
end=time.perf_counter()
print(end - start)

output:

"""
4.565093299999717
"""

执行时间减少了一半多一点。 如果我们使用set_probs_()而不是set_probs() ,它可以进一步减少。 set_probs()set_probs_()之间存在细微差别,后者跳过了probs / probs.sum(-1, keepdim=True)行,该行应该消除浮点错误。 但是,它可能并不总是必要的。

start=time.perf_counter()
ct = Categorical((3,4,2)) #<--- initialize class beforehand
for _ in range(50000):
    probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
    ct.set_probs_(probs)
    entropy = ct.entropy()
    action = ct.sample()
    log_prob = ct.log_prob(action)
    entropy, action, log_prob
end=time.perf_counter()
print(end - start)

output:

"""
3.9343119999975897
"""

您可以在您机器上的某个位置检查 pytorch 分发模块的源代码..\Lib\site-packages\torch\distributions

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM