
Slow performance of PyTorch Categorical

I have been using a PPO (Proximal Policy Optimisation) architecture for training my agent in a custom simulator. My simulator has become quite fast as it is written in Rust. The speed of my inner loop is therefore bottlenecked by some functions that are inside the PPO agent.

When I profiled the function with pyinstrument it showed that most of the time is spent on initialising the Categorical class and calculating the log probabilities.

I hope someone can help: is there a faster way to do this using PyTorch?

    def act(self, state):
        action_probs = self.actor(state)
        dist = Categorical(action_probs)

        action = dist.sample()
        action_logprob = dist.log_prob(action)

        return action.detach(), action_logprob.detach()

    def evaluate(self, state, action):
        """Evaluates the action given the state."""
        action_probs = self.actor(state)
        dist = Categorical(action_probs)

        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_values = self.critic(state)

        return action_logprobs, state_values, dist_entropy

Pyinstrument output showing where the program spends its time.

I have seen some other techniques for doing this, but it was not clear to me whether they would improve the speed.

I ran into the same problem a while back and implemented my own Categorical class by copying from the PyTorch source code.

It is similar to the original code but removes functionality that is not needed here. It does not require constructing a new object on every step: initialize it once, then call set_probs() or set_probs_() to set new probability values. It also accepts only probabilities as input (not logits), but you can apply softmax to the logits yourself.

import torch
from torch.distributions.utils import probs_to_logits
class Categorical:
    def __init__(self, probs_shape):
        # NOTE: probs_shape is supposed to be
        #       the shape of probs that will be
        #       produced by policy network
        if len(probs_shape) < 1:
            raise ValueError("`probs_shape` must have at least 1 dimension.")
        self.probs_dim = len(probs_shape)
        self.probs_shape = probs_shape
        self._num_events = probs_shape[-1]
        # Store shapes as torch.Size so they concatenate cleanly in sample()
        self._batch_shape = torch.Size(probs_shape[:-1])
        self._event_shape = torch.Size()

    def set_probs_(self, probs):
        self.probs = probs
        self.logits = probs_to_logits(self.probs)

    def set_probs(self, probs):
        self.probs = probs / probs.sum(-1, keepdim=True) 
        self.logits = probs_to_logits(self.probs)

    def sample(self, sample_shape=torch.Size()):
        if not isinstance(sample_shape, torch.Size):
            sample_shape = torch.Size(sample_shape)
        probs_2d = self.probs.reshape(-1, self._num_events)
        samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T
        return samples_2d.reshape(sample_shape + self._batch_shape + self._event_shape)

    def log_prob(self, value):
        value = value.long().unsqueeze(-1)
        value, log_pmf = torch.broadcast_tensors(value, self.logits)
        value = value[..., :1]
        return log_pmf.gather(-1, value).squeeze(-1)

    def entropy(self):
        min_real = torch.finfo(self.logits.dtype).min
        logits = torch.clamp(self.logits, min=min_real)
        p_log_p = logits * self.probs
        return -p_log_p.sum(-1)
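
To make clear what log_prob() and entropy() above compute, here is a minimal sketch that evaluates the same quantities directly with tensor operations and compares them against the built-in torch.distributions.Categorical. The variable names (manual_log_prob, manual_entropy, ref) and the random test tensor are only illustrative; this is a sanity check, not part of the class itself.

```python
import torch
from torch.distributions import Categorical as TorchCategorical

torch.manual_seed(0)

# Stand-in for the policy network output: batch shape (3, 4), 2 actions.
logits = torch.randn(3, 4, 2)
log_probs = torch.log_softmax(logits, dim=-1)  # normalized log-probabilities
probs = log_probs.exp()

# Sample one action per batch element, as sample() does via torch.multinomial.
actions = torch.multinomial(probs.reshape(-1, 2), 1).reshape(3, 4)

# log_prob: pick the log-probability of the chosen action along the last dim.
manual_log_prob = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

# entropy: -sum_i p_i * log p_i over the last dim.
manual_entropy = -(probs * log_probs).sum(-1)

# Reference values from the built-in class.
ref = TorchCategorical(logits=logits)
print(torch.allclose(manual_log_prob, ref.log_prob(actions), atol=1e-6))
print(torch.allclose(manual_entropy, ref.entropy(), atol=1e-6))
```

If both comparisons print True on your setup, the stripped-down arithmetic agrees with the library for this input.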


Checking execution time:

import time
import torch as tt
import torch.distributions as td

First, time the built-in torch.distributions.Categorical:

start=time.perf_counter()
for _ in range(50000):
    probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
    ct = td.Categorical(probs=probs)
    entropy = ct.entropy()
    action = ct.sample()
    log_prob = ct.log_prob(action)
    entropy, action, log_prob
end=time.perf_counter()
print(end - start)

output:

"""
10.024958199996036
"""

Now time the custom Categorical:

start=time.perf_counter()
ct = Categorical((3,4,2)) #<--- initialize class beforehand
for _ in range(50000):
    probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
    ct.set_probs(probs)
    entropy = ct.entropy()
    action = ct.sample()
    log_prob = ct.log_prob(action)
    entropy, action, log_prob
end=time.perf_counter()
print(end - start)

output:

"""
4.565093299999717
"""

The execution time dropped by a little more than half. It can be reduced further by using set_probs_() instead of set_probs(). The only difference between them is that set_probs_() skips the line probs / probs.sum(-1, keepdim=True), which renormalizes the probabilities to remove floating-point error. That step is not always necessary, for example when the probabilities come straight from a softmax.
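
As a rough check of how much that renormalization matters, the sketch below compares softmax output with its renormalized version. The tensor shape matches the benchmark above; the name max_diff is only illustrative, and the exact value will depend on dtype and hardware.

```python
import torch

torch.manual_seed(0)

# Probabilities as the benchmark produces them: softmax over the last dim.
probs = torch.softmax(torch.rand(3, 4, 2), dim=-1)

# What set_probs() does in addition to set_probs_().
renormalized = probs / probs.sum(-1, keepdim=True)

# Largest element-wise change introduced by renormalizing.
max_diff = (probs - renormalized).abs().max().item()
print(max_diff)
```

If your probabilities do not come from a softmax (for example, after masking out invalid actions), the sums can be far from 1 and set_probs() is the safer choice.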

start=time.perf_counter()
ct = Categorical((3,4,2)) #<--- initialize class beforehand
for _ in range(50000):
    probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
    ct.set_probs_(probs)
    entropy = ct.entropy()
    action = ct.sample()
    log_prob = ct.log_prob(action)
    entropy, action, log_prob
end=time.perf_counter()
print(end - start)

output:

"""
3.9343119999975897
"""
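
A separate option, not timed here, is to keep the built-in class but pass validate_args=False, which skips the argument checks done when the distribution is constructed and when log_prob() is called. This is a minimal sketch of that idea; whether it helps enough for your loop is something you would need to measure.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)
probs = torch.softmax(torch.rand(3, 4, 2), dim=-1)

# validate_args=False disables the constraint checks on probs and on
# the values passed to log_prob().
dist = Categorical(probs=probs, validate_args=False)
action = dist.sample()
log_prob = dist.log_prob(action)
entropy = dist.entropy()
print(action.shape, log_prob.shape, entropy.shape)
```

With validation off, invalid inputs (such as negative probabilities) are no longer caught early, so this is best used once the rest of the pipeline is known to be correct.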

You can find the source code for the PyTorch distributions module on your machine, somewhere under ..\Lib\site-packages\torch\distributions
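
If you are not sure where that directory is in your environment, one way to find it is to ask the imported module for its own location. The variable name dist_dir is only illustrative.

```python
import os
import torch.distributions as td

# Directory that holds the installed torch.distributions package.
dist_dir = os.path.dirname(td.__file__)
print(dist_dir)

# The Categorical implementation is expected to live in categorical.py there.
print(os.path.exists(os.path.join(dist_dir, "categorical.py")))
```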
