簡體   English   中英

PyTorch 分類性能緩慢

[英]Slow performance of PyTorch Categorical

我一直在使用 PPO(近端策略優化)架構在自定義模擬器中訓練我的代理。 我的模擬器已經變得非常快,因為它寫在 Rust 中。因此,我的內部循環的速度受到 PPO 代理內部的某些功能的瓶頸。

當我使用 pyinstrument 分析 function 時,它表明大部分時間都花在初始化分類 class 和計算對數概率上。

我希望有人可以提供幫助,如果有更快的方法可以使用 PyTorch 來完成此操作。

    def act(self, state):
        action_probs = self.actor(state)
        dist = Categorical(action_probs)

        action = dist.sample()
        action_logprob = dist.log_prob(action)

        return action.detach(), action_logprob.detach()

    def evaluate(self, state, action):
        """Evaluates the action given the state."""
        action_probs = self.actor(state)
        dist = Categorical(action_probs)

        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_values = self.critic(state)

        return action_logprobs, state_values, dist_entropy

Pyinstrument 顯示程序的速度。

我已經看到一些其他技術可以做到這一點,但我不太清楚它們是否會提高速度。

我前一段時間遇到了同樣的問題,並通過從pytorch 源代碼復制來實現我的自定義Categorical class

它類似於原始代碼,但刪除了不必要的功能。 不需要每次都初始化 class ,而是初始化一次,只需使用set_probs()set_probs_()來設置新的概率值。 此外,它僅適用於概率值作為輸入(不是 logits),但我們無論如何都可以在 logits 上手動應用softmax

import torch
from torch.distributions.utils import probs_to_logits
class Categorical:
    def __init__(self, probs_shape): 
        # NOTE: probs_shape is supposed to be 
        #       the shape of probs that will be 
        #       produced by policy network
        if len(probs_shape) < 1: 
            raise ValueError("`probs_shape` must be at least 1.")
        self.probs_dim = len(probs_shape) 
        self.probs_shape = probs_shape
        self._num_events = probs_shape[-1]
        self._batch_shape = probs_shape[:-1] if self.probs_dim > 1 else torch.Size()
        self._event_shape=torch.Size()

    def set_probs_(self, probs):
        self.probs = probs
        self.logits = probs_to_logits(self.probs)

    def set_probs(self, probs):
        self.probs = probs / probs.sum(-1, keepdim=True) 
        self.logits = probs_to_logits(self.probs)

    def sample(self, sample_shape=torch.Size()):
        if not isinstance(sample_shape, torch.Size):
            sample_shape = torch.Size(sample_shape)
        probs_2d = self.probs.reshape(-1, self._num_events)
        samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T
        return samples_2d.reshape(sample_shape + self._batch_shape + self._event_shape)

    def log_prob(self, value):
        value = value.long().unsqueeze(-1)
        value, log_pmf = torch.broadcast_tensors(value, self.logits)
        value = value[..., :1]
        return log_pmf.gather(-1, value).squeeze(-1)

    def entropy(self):
        min_real = torch.finfo(self.logits.dtype).min
        logits = torch.clamp(self.logits, min=min_real)
        p_log_p = logits * self.probs
        return -p_log_p.sum(-1)


檢查執行時間:

import time
import torch as tt
import torch.distributions as td

首先檢查內置torch.distributions.Categorical

start=time.perf_counter()
for _ in range(50000):
    probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
    ct = td.Categorical(probs=probs)
    entropy = ct.entropy()
    action = ct.sample()
    log_prob = ct.log_prob(action)
    entropy, action, log_prob
end=time.perf_counter()
print(end - start)

output:

"""
10.024958199996036
"""

現在檢查自定義Categorical

start=time.perf_counter()
ct = Categorical((3,4,2)) #<--- initialize class beforehand
for _ in range(50000):
    probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
    ct.set_probs(probs)
    entropy = ct.entropy()
    action = ct.sample()
    log_prob = ct.log_prob(action)
    entropy, action, log_prob
end=time.perf_counter()
print(end - start)

output:

"""
4.565093299999717
"""

執行時間減少了一半多一點。 如果我們使用set_probs_()而不是set_probs() ,它可以進一步減少。 set_probs()set_probs_()之間存在細微差別,后者跳過了probs / probs.sum(-1, keepdim=True)行,該行應該消除浮點錯誤。 但是,它可能並不總是必要的。

start=time.perf_counter()
ct = Categorical((3,4,2)) #<--- initialize class beforehand
for _ in range(50000):
    probs = tt.softmax(tt.rand((3,4,2)), dim=-1)
    ct.set_probs_(probs)
    entropy = ct.entropy()
    action = ct.sample()
    log_prob = ct.log_prob(action)
    entropy, action, log_prob
end=time.perf_counter()
print(end - start)

output:

"""
3.9343119999975897
"""

您可以在您機器上的某個位置檢查 pytorch 分發模塊的源代碼..\Lib\site-packages\torch\distributions

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM