簡體   English   中英

神經網絡中的權重

[英]Weights in Neural Network

我正在閱讀: https://towardsdatascience.com/how-to-build-your-own-neural-network-from-scratch-in-python-68998a08e4f6

我看到以下代碼:

import numpy as np

def sigmoid(x):
    return 1.0/(1+ np.exp(-x))

def sigmoid_derivative(x):
    return x * (1.0 - x)

class NeuralNetwork:
    def __init__(self, x, y):
        self.input      = x
        self.weights1   = np.random.rand(self.input.shape[1],4) 
        self.weights2   = np.random.rand(4,1)                 
        self.y          = y
        self.output     = np.zeros(self.y.shape)

    def feedforward(self):
        self.layer1 = sigmoid(np.dot(self.input, self.weights1))
        self.output = sigmoid(np.dot(self.layer1, self.weights2))

    def backprop(self):
        # application of the chain rule to find derivative of the loss function with respect to weights2 and weights1
        d_weights2 = np.dot(self.layer1.T, (2*(self.y - self.output) * sigmoid_derivative(self.output)))
        d_weights1 = np.dot(self.input.T,  (np.dot(2*(self.y - self.output) * sigmoid_derivative(self.output), self.weights2.T) * sigmoid_derivative(self.layer1)))

        # update the weights with the derivative (slope) of the loss function
        self.weights1 += d_weights1
        self.weights2 += d_weights2


if __name__ == "__main__":
    X = np.array([[0,0,1],
                  [0,1,1],
                  [1,0,1],
                  [1,1,1]])
    y = np.array([[0],[1],[1],[0]])
    nn = NeuralNetwork(X,y)

    for i in range(1500):
        nn.feedforward()
        nn.backprop()

    print(nn.output)

權重不應該是 4x4 隨機矩陣,因為我們在隱藏層中有 4 個神經元和 4 個輸入值,所以權重的總數應該是 16,但是下面的代碼在 init function 中分配了一個 2x4 的矩陣並創建了一個點積?

您的輸入矩陣X表明樣本數為 4,特征數為 3。神經網絡輸入層中的神經元數等於特征數*,而不是樣本數。 例如,假設您有 4 輛汽車,您為每輛汽車選擇了 3 個特征:顏色、座位數和原產國。 對於每個汽車樣本,您將這 3 個特征輸入網絡並訓練您的 model。 即使你有 4000 個樣本,輸入神經元的數量也不會改變; 是 3。

因此self.weights1的形狀為(3, 4) ,其中 3 是特征數,4 是隱藏神經元數(這 4 與樣本數無關),正如預期的那樣。

*:有時輸入會增加1 (或-1 )以考慮偏差,因此在這種情況下輸入神經元的數量將是num_features + 1 但這是一個選擇是否單獨處理偏差。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM