簡體   English   中英

PyTorch 中 nn.Linear 的類定義是什么?

[英]What is the class definition of nn.Linear in PyTorch?

我有以下 PyTorch 代碼:

import torch.nn as nn
import torch.nn.functional as F

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(784, 256)
        self.output = nn.Linear(256, 10)
    
    def forward(self, x):
        x = F.sigmoid(self.hidden(x))
        x = F.softmax(self.output(x), dim=1)
    
        return x

我的問題:這個self.hidden是什么?

它從nn.Linear返回並且可以將x作為參數。 self.hidden的目的究竟是什么?

pytorch 中 nn.Linear 的類定義是什么?

文檔


CLASS torch.nn.Linear(in_features, out_features, bias=True)

對傳入數據應用線性變換: y = x*W^T + b

參數:

  • in_features – 每個輸入樣本的大小(即 x 的大小)
  • out_features – 每個輸出樣本的大小(即 y 的大小)
  • 偏差- 如果設置為 False,該層將不會學習附加偏差。 默認值:真

請注意,權重W具有形狀(out_features, in_features)並且偏差b具有形狀(out_features) 它們是隨機初始化的,以后可以更改(例如,在訓練神經網絡期間,它們會通過某些優化算法進行更新)。

在你的神經網絡中, self.hidden = nn.Linear(784, 256)定義了一個隱藏的(意味着它在輸入層和輸出層之間),完全連接的線性層,它采用輸入x的形狀(batch_size, 784) ,其中批量大小是一次傳遞到網絡的輸入數量(每個大小為 784)(作為單個張量),並通過線性方程y = x*W^T + b轉換為張量y形狀(batch_size, 256) 它由 sigmoid 函數進一步轉換, x = F.sigmoid(self.hidden(x)) (它不是nn.Linear的一部分,而是一個附加步驟)。

讓我們看一個具體的例子:

import torch
import torch.nn as nn

x = torch.tensor([[1.0, -1.0],
                  [0.0,  1.0],
                  [0.0,  0.0]])

in_features = x.shape[1]  # = 2
out_features = 2

m = nn.Linear(in_features, out_features)

其中x包含三個輸入(即批量大小為 3)、 x[0]x[1]x[3] ,每個大小為 2,輸出的形狀為(batch size, out_features) = (3, 2)

參數(權重和偏差)的值是:

>>> m.weight
tensor([[-0.4500,  0.5856],
        [-0.1807, -0.4963]])

>>> m.bias
tensor([ 0.2223, -0.6114])

(因為它們是隨機初始化的,很可能你會得到與上面不同的值)

輸出是:

>>> y = m(x)
tensor([[-0.8133, -0.2959],
        [ 0.8079, -1.1077],
        [ 0.2223, -0.6114]])

並且(在幕后)它被計算為:

y = x.matmul(m.weight.t()) + m.bias  # y = x*W^T + b

IE

y[i,j] == x[i,0] * m.weight[j,0] + x[i,1] * m.weight[j,1] + m.bias[j]

其中i在區間[0, batch_size)j[0, out_features)

Network定義為具有兩層,隱藏層和輸出層。 粗略地說,隱藏層的功能是保存可以在訓練過程中優化的參數。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM