簡體   English   中英

如何在 PyTorch 中返回一個自定義激活 function 的可訓練參數?

[英]How to return of one trainable parameters of custom activation function in PyTorch?

import torch
from another_file import Custom_activation_function as caf
class Example(torch.nn.Module):
   def __init__(self):
      super().__init__()
      self.model = torch.nn.Sequential(
         torch.nn.Linear(4, 250, bias=False),
         caf(),
         torch.nn.Linear(250, 250, bias=False),
         caf(),
         torch.nn.Linear(250, 1, bias=False),
      )
   def forward(self, x):
      return model(x)
x = torch.randn((100, 4))
y = model(x)

下面提供了 Custom_activation_function 的偽代碼:

class Custom_activation_function(torch.nn.Module):
   def forward(self, x):
     y = x * x
     b = torch.exp(x)
     return y

在上面的兩段代碼中,Custom_activation_function 是一個自定義激活 function 的 class,多個線性層和激活函數在一個順序的 Z20F35E630DAF44DBFA4C3F68F5399D8Z 中捆綁在一起。 在訓練 model 時,我們希望訪問 Custom_activation_function 中定義的 b(例如,在損失 function 中的正則化項包括 b 的情況下)。 最好在不改變順序 model 的情況下,在訓練期間如何獲得 b?

如果您需要訪問b作為中間 state (如在您的偽代碼中),您可以嘗試以下操作:

import torch


class caf(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        y = x * x
        self.b = torch.exp(x)
        return y


class Example(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.model = torch.nn.Sequential(
            torch.nn.Linear(4, 250, bias=False),
            caf(),
            torch.nn.Linear(250, 250, bias=False),
            caf(),
            torch.nn.Linear(250, 1, bias=False),
        )

    def forward(self, x):
        return self.model(x)


my_model = Example()
optimizer = torch.optim.SGD(my_model.parameters(), lr=0.01)
loss_func = torch.nn.MSELoss()

x = torch.randn((100, 4))
y = torch.randn((100, 1))

for iteration in range(1000):
    my_model.zero_grad()
    out = my_model(x)

    loss_base = loss_func(out, y)
    loss_reg = []

    # Identify your b's
    for mod in my_model.model.modules():
        if isinstance(mod, caf):
            loss_reg.append(mod.b)

    loss_reg = torch.stack(loss_reg).mean()
    loss = loss_base + loss_reg
    loss.backward()
    optimizer.step()
    print(loss.item(), loss_base.item(), loss_reg.item())

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM