
Slow performance of PyTorch Categorical

I have been using a PPO (Proximal Policy Optimisation) architecture to train my agent in a custom simulator. The simulator has become quite fast because it is written in Rust, so the speed of my inner loop is now bottlenecked by some functions inside the PPO agent.

When I profiled the function with pyinstrument, it showed that most of the time is spent on initialising the Categorical class and calculating the log probabilities.

I hope someone can help and tell me whether there is a faster way to do this using PyTorch.

    def act(self, state):
        action_probs = self.actor(state)
        dist = Categorical(action_probs)

        action = dist.sample()
        action_logprob = dist.log_prob(action)

        return action.detach(), action_logprob.detach()

    def evaluate(self, state, action):
        """Evaluates the action given the state."""
        action_probs = self.actor(state)
        dist = Categorical(action_probs)

        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_values = self.critic(state)

        return action_logprobs, state_values, dist_entropy

(Pyinstrument profile showing where the program spends its time.)

I have seen some other techniques to do this, but it was not very clear to me whether they would improve the speed.
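One tweak that is sometimes suggested (it is not part of the original post, so treat it as an assumption worth benchmarking) is to skip input validation when constructing the distribution. By default, `Categorical` checks the probability tensor on every construction; passing `validate_args=False` skips those checks, which can matter when the constructor sits in a tight loop:

```python
import torch
from torch.distributions import Categorical

probs = torch.softmax(torch.rand(3, 4, 2), dim=-1)

# Default: the constructor validates the probability tensor
# (shape, non-negativity, normalization) on every call.
dist = Categorical(probs=probs)

# validate_args=False skips those per-call checks.
fast_dist = Categorical(probs=probs, validate_args=False)

action = fast_dist.sample()
log_prob = fast_dist.log_prob(action)
```

Both objects represent the same distribution; only the construction-time checking differs, so the payoff depends on how often the constructor runs.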

I ran into the same problem a while back and implemented my own custom Categorical class by copying from the PyTorch source code.

It is similar to the original code but removes unnecessary functionality. It does not require initializing the class every time; instead, initialize it once and just use set_probs() or set_probs_() to set new probability values. Also, it works only with probability values as input (not logits), but we can manually apply softmax to logits anyway.

import torch
from torch.distributions.utils import probs_to_logits
class Categorical:
    def __init__(self, probs_shape): 
        # NOTE: probs_shape is supposed to be 
        #       the shape of probs that will be 
        #       produced by policy network
        if len(probs_shape) < 1: 
            raise ValueError("`probs_shape` must be at least 1.")
        self.probs_dim = len(probs_shape) 
        self.probs_shape = probs_shape
        self._num_events = probs_shape[-1]
        self._batch_shape = probs_shape[:-1] if self.probs_dim > 1 else torch.Size()
        self._event_shape = torch.Size()

    def set_probs_(self, probs):
        self.probs = probs
        self.logits = probs_to_logits(self.probs)

    def set_probs(self, probs):
        self.probs = probs / probs.sum(-1, keepdim=True) 
        self.logits = probs_to_logits(self.probs)

    def sample(self, sample_shape=torch.Size()):
        if not isinstance(sample_shape, torch.Size):
            sample_shape = torch.Size(sample_shape)
        probs_2d = self.probs.reshape(-1, self._num_events)
        samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T
        return samples_2d.reshape(sample_shape + self._batch_shape + self._event_shape)

    def log_prob(self, value):
        value = value.long().unsqueeze(-1)
        value, log_pmf = torch.broadcast_tensors(value, self.logits)
        value = value[..., :1]
        return log_pmf.gather(-1, value).squeeze(-1)

    def entropy(self):
        min_real = torch.finfo(self.logits.dtype).min
        logits = torch.clamp(self.logits, min=min_real)
        p_log_p = logits * self.probs
        return -p_log_p.sum(-1)


Checking execution time:

import time
import torch as tt
import torch.distributions as td

First check the inbuilt torch.distributions.Categorical:

start=time.perf_counter()
for _ in range(50000):
    probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
    ct = td.Categorical(probs=probs)
    entropy = ct.entropy()
    action = ct.sample()
    log_prob = ct.log_prob(action)
    entropy, action, log_prob
end=time.perf_counter()
print(end - start)

output:

"""
10.024958199996036
"""

Now check the custom Categorical:

start=time.perf_counter()
ct = Categorical((3,4,2)) #<--- initialize class beforehand
for _ in range(50000):
    probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
    ct.set_probs(probs)
    entropy = ct.entropy()
    action = ct.sample()
    log_prob = ct.log_prob(action)
    entropy, action, log_prob
end=time.perf_counter()
print(end - start)

output:

"""
4.565093299999717
"""

The execution time dropped by a little more than half. It can be reduced further by using set_probs_() instead of set_probs(). The subtle difference between set_probs() and set_probs_() is that the latter skips the line probs / probs.sum(-1, keepdim=True), which is meant to remove floating-point errors. However, it is not always necessary.
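To see why that normalization line matters, here is a small standalone sketch (the tensor values are made up for illustration). probs_to_logits() simply takes a log of the clamped probabilities, so if the input does not already sum to 1, the set_probs_() path produces logits, and hence log_prob() values, shifted by the log of the sum:

```python
import torch
from torch.distributions.utils import probs_to_logits

# "Probabilities" that sum to 2.0 rather than 1.0.
raw = torch.tensor([0.4, 0.6, 1.0])

# set_probs() path: renormalize before converting to logits.
logits_safe = probs_to_logits(raw / raw.sum(-1, keepdim=True))

# set_probs_() path: trust the caller and skip the division.
logits_fast = probs_to_logits(raw)

# The fast path is off by log(sum) = log(2) per element, so
# log_prob() would be shifted by the same constant.
print(logits_fast - logits_safe)
```

After a softmax the probabilities already sum to 1, so in the benchmark above the shortcut is safe; the guard only matters when the caller cannot promise normalized input.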

start=time.perf_counter()
ct = Categorical((3,4,2)) #<--- initialize class beforehand
for _ in range(50000):
    probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
    ct.set_probs_(probs)
    entropy = ct.entropy()
    action = ct.sample()
    log_prob = ct.log_prob(action)
    entropy, action, log_prob
end=time.perf_counter()
print(end - start)

output:

"""
3.9343119999975897
"""

You can check the source code of the PyTorch distributions module on your machine, somewhere at ..\Lib\site-packages\torch\distributions.
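If you would rather not hunt for the directory by hand, the install location can also be found programmatically (a small convenience sketch, not from the original answer):

```python
import os
import torch.distributions as td

# Print where the distributions package lives on this machine.
print(os.path.dirname(td.__file__))
```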
