
Linear Regression with gradient descent: two questions

I'm trying to understand linear regression with gradient descent, and I don't understand one part of my loss_gradients function below.

import numpy as np

def forward_linear_regression(X, y, weights):

    # dot product weights * inputs
    N = np.dot(X, weights['W'])

    # add bias
    P = N + weights['B']

    # compute loss with MSE
    loss = np.mean(np.power(y - P, 2))

    forward_info = {}
    forward_info['X'] = X
    forward_info['N'] = N
    forward_info['P'] = P
    forward_info['y'] = y

    return loss, forward_info
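For reference, here is a minimal usage sketch of the forward pass on a tiny hand-checkable batch. The shapes are assumptions (X is (batch, k), weights['W'] is (k, 1), weights['B'] is (1, 1), y is (batch, 1)), and the data values are made up so the result can be checked by hand.

```python
import numpy as np

def forward_linear_regression(X, y, weights):
    # Linear part: (batch, k) @ (k, 1) -> (batch, 1)
    N = np.dot(X, weights['W'])
    # Add the bias; weights['B'] has shape (1, 1) and broadcasts over the batch
    P = N + weights['B']
    # Mean squared error over the batch
    loss = np.mean(np.power(y - P, 2))
    forward_info = {'X': X, 'N': N, 'P': P, 'y': y}
    return loss, forward_info

# Hypothetical toy batch: 2 samples, 2 features
X = np.array([[1.0, 2.0], [3.0, 4.0]])
y = np.array([[3.0], [8.0]])
weights = {'W': np.array([[1.0], [1.0]]), 'B': np.array([[0.5]])}

loss, info = forward_linear_regression(X, y, weights)
# N is [[3], [7]], P is [[3.5], [7.5]], each squared error is 0.25, so loss is 0.25
```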

Here is where I'm stuck; I've written my questions as comments:

def loss_gradients(forward_info, weights):

    # to update weights, we need: dLdW = dLdP * dPdN * dNdW
    dLdP = -2 * (forward_info['y'] - forward_info['P'])
    dPdN = np.ones_like(forward_info['N'])
    dNdW = np.transpose(forward_info['X'], (1, 0))

    dLdW = np.dot(dNdW, dLdP * dPdN)
    # why do we mix element-wise multiplication (*) and the dot product (np.dot) like this?
    # Why not dLdP * dPdN * dNdW instead?

    # to update biases, we need: dLdB = dLdP * dPdB
    # the bias gradient needs an array shaped like the bias itself
    dPdB = np.ones_like(weights['B'])
    dLdB = np.sum(dLdP * dPdB, axis=0)
    # why do we sum those values along axis 0?
    # why not just dLdP * dPdB?

    return {'W': dLdW, 'B': dLdB}
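To see the two functions work together, here is a minimal gradient-descent sketch. It assumes the bias line reads np.ones_like(weights['B']), that loss_gradients returns a dict with keys 'W' and 'B' (an assumption; the snippet does not otherwise show its return value), noise-free synthetic data, and an arbitrarily chosen learning rate of 0.001.

```python
import numpy as np

def forward_linear_regression(X, y, weights):
    N = np.dot(X, weights['W'])          # (batch, k) @ (k, 1) -> (batch, 1)
    P = N + weights['B']                 # bias (1, 1) broadcasts over the batch
    loss = np.mean(np.power(y - P, 2))   # mean squared error
    return loss, {'X': X, 'N': N, 'P': P, 'y': y}

def loss_gradients(forward_info, weights):
    dLdP = -2 * (forward_info['y'] - forward_info['P'])   # (batch, 1)
    dPdN = np.ones_like(forward_info['N'])                 # (batch, 1)
    dNdW = np.transpose(forward_info['X'], (1, 0))         # (k, batch)
    dLdW = np.dot(dNdW, dLdP * dPdN)                       # (k, 1), summed over the batch
    dPdB = np.ones_like(weights['B'])                      # (1, 1)
    dLdB = np.sum(dLdP * dPdB, axis=0)                     # (1,), summed over the batch
    return {'W': dLdW, 'B': dLdB}

# Hypothetical synthetic data: y = 2*x0 - 3*x1 + 1, no noise
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
true_W = np.array([[2.0], [-3.0]])
y = X @ true_W + 1.0

weights = {'W': np.zeros((2, 1)), 'B': np.zeros((1, 1))}
learning_rate = 0.001   # assumed value, small because gradients are summed, not averaged

for step in range(1000):
    loss, info = forward_linear_regression(X, y, weights)
    grads = loss_gradients(info, weights)
    # Step against the gradient
    weights['W'] -= learning_rate * grads['W']
    weights['B'] -= learning_rate * grads['B']
```

Because the gradients are summed over the batch rather than averaged, a larger batch produces a larger gradient, so the learning rate would need to shrink accordingly.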

It looks to me like this code expects a batch of data. That is, when you call forward_linear_regression and loss_gradients, you're actually passing many (X, y) pairs at once, stacked as rows. Say you pass B such pairs (B here is the batch size, not the bias weights['B']). Then the first dimension of every array in forward_info has size B.
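Writing the chain rule per sample makes the batch structure explicit. This is a sketch assuming each row x_i of X is one sample, and it uses b for the bias weights['B'] to avoid a clash with the batch size B. Note that forward_linear_regression reports the mean loss, while dLdP = -2 * (y - P) is the per-sample derivative without a 1/B factor, so the summed gradient is B times the gradient of the mean loss; in practice that factor is absorbed into the learning rate.

```latex
% X is B x k with rows x_i (each 1 x k), W is k x 1, b is the scalar bias,
% and g = dLdP is B x 1 with entries g_i.
\begin{aligned}
p_i &= x_i W + b, \qquad \ell_i = (y_i - p_i)^2, \qquad
  g_i = \frac{\partial \ell_i}{\partial p_i} = -2\,(y_i - p_i) \\
\frac{\partial \ell_i}{\partial W} &= x_i^{\top} g_i, \qquad
  \frac{\partial \ell_i}{\partial b} = g_i \\
\sum_{i=1}^{B} \frac{\partial \ell_i}{\partial W} &= \sum_{i=1}^{B} x_i^{\top} g_i = X^{\top} g
  \quad (\texttt{np.dot(dNdW, dLdP * dPdN)}) \\
\sum_{i=1}^{B} \frac{\partial \ell_i}{\partial b} &= \sum_{i=1}^{B} g_i = \mathbf{1}^{\top} g
  \quad (\texttt{np.sum(dLdP * dPdB, axis=0)})
\end{aligned}
```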

Now, the answer to both of your questions is the same: these lines compute the gradient for each of the B samples (using exactly the formulas you predicted) and then sum those per-sample gradients to get a single gradient update. For the weights, the sum over the batch happens inside the matrix product np.dot(dNdW, ...); for the bias, np.sum(..., axis=0) does it explicitly. I encourage you to work out the logic behind the dot product yourself, because this is a very common pattern in ML, but it takes a little while to get the hang of.
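One way to check this numerically is to compute each sample's gradient in a Python loop, add them up, and compare with the vectorized lines. This is a sketch on a small random batch (B = 4, k = 3, arbitrary values); it also shows that the literal element-wise product dLdP * dNdW from the question does not even broadcast, because the shapes (4, 1) and (3, 4) are incompatible.

```python
import numpy as np

# Hypothetical batch: B = 4 samples, k = 3 features
rng = np.random.default_rng(1)
X = rng.normal(size=(4, 3))
y = rng.normal(size=(4, 1))
W = rng.normal(size=(3, 1))
b = np.array([[0.1]])

P = X @ W + b
dLdP = -2 * (y - P)                 # (4, 1): one scalar gradient per sample
dNdW = X.T                          # (3, 4)

# Vectorized versions, as in loss_gradients
dLdW_vec = np.dot(dNdW, dLdP)       # (3, 4) @ (4, 1) -> (3, 1)
dLdB_vec = np.sum(dLdP, axis=0)     # (1,)

# Per-sample version: compute each sample's gradient, then add them up
dLdW_loop = np.zeros((3, 1))
dLdB_loop = np.zeros(1)
for i in range(X.shape[0]):
    x_i = X[i:i + 1, :]             # (1, 3)
    g_i = dLdP[i:i + 1, :]          # (1, 1)
    dLdW_loop += x_i.T * g_i        # (3, 1): the chain-rule product for one sample
    dLdB_loop += g_i[0]             # dPdB is 1 for every sample

assert np.allclose(dLdW_vec, dLdW_loop)
assert np.allclose(dLdB_vec, dLdB_loop)

# The naive element-wise product over the whole batch fails to broadcast
broadcast_failed = False
try:
    dLdP * dNdW                     # (4, 1) * (3, 4)
except ValueError:
    broadcast_failed = True
```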
