
Linear Regression Stochastic Gradient Descent

I am trying to fit a sinusoidal wave (sin(2 pi x)) with some Gaussian noise added to it. I am using the stochastic gradient descent algorithm, and the model I am trying to fit is linear in the parameters. I have used a simple polynomial basis, [1, x^1, x^2, ..., x^5]. The loss function is the least-squares loss.

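import numpy as np

def gradient_descent(phi, Y, W, a):
    # One pass over the data, updating W after every single sample
    N = len(Y)
    for i in range(N):
        # Gradient of 0.5 * (prediction - target)^2 on sample i
        dE_dW = (np.dot(W, phi[i]) - Y[i]) * phi[i]
        W = W - a * dE_dW
    return W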

For sampling, I am doing this:

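import random
from math import sin  # assumption: the original import of sin is not shown
import numpy as np

# sample_size is assumed to be defined earlier
X, Y = [], []
noise_sample = np.random.normal(loc=0, scale=0.07, size=sample_size)
for i in range(sample_size):
    x = random.uniform(0.0, 0.5)
    y = sin(x)
    X.append(x)
    Y.append(y)
X, Y = np.array(X), np.array(Y)
permutation = np.random.permutation(sample_size)
X, Y = X[permutation], Y[permutation]
Y = np.add(Y, noise_sample)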

order = 5
phi = np.array([np.ones(sample_size)]).T
for i in range(order):
   phi = np.c_[phi, X ** (i + 1)]
W = np.random.uniform(low=0.0, high=1.0, size=(order+1,))

I am getting this as the fitted curve in this case (orange).

[Figure: sin(2 pi x)]

When I try for the same degree using the closed form solution,

phi_inv = np.matmul(np.linalg.inv(np.matmul(phi.T, phi)), phi.T)
weights = np.matmul(phi_inv, Y.T)

I am getting the desired curve. Is there something I am doing wrong?

This can be caused by a step size (learning rate) a that is too large. The gradient you are computing is just a noisy version of the true gradient. If your step size is too large, you are just jumping around almost randomly. Of course, if you choose it too small, you will never reach the optimum and will just stay close to where you started.
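To make the step-size effect concrete: for the squared loss, a single update on sample i multiplies that sample's residual by (1 - a * ||phi_i||^2), so any a above 2 / ||phi_i||^2 makes the residual on that sample grow instead of shrink. Below is a minimal sketch of that check. The helper `residual_after_step` and the example values (x = 0.5, target 1.0, zero weights) are illustrative assumptions, not taken from the question.

```python
import numpy as np

def residual_after_step(w, phi_i, y_i, a):
    """Residual on sample i before and after one SGD step on that same sample."""
    before = np.dot(w, phi_i) - y_i
    w_new = w - a * before * phi_i   # same update rule as in the question
    after = np.dot(w_new, phi_i) - y_i
    return before, after

# Illustrative sample: x = 0.5 with the basis [1, x, ..., x^5]
phi_i = 0.5 ** np.arange(6)
w = np.zeros(6)
y_i = 1.0

bound = 2.0 / np.dot(phi_i, phi_i)   # roughly 1.5 for this sample
small = residual_after_step(w, phi_i, y_i, a=0.1)   # below the bound
large = residual_after_step(w, phi_i, y_i, a=2.0)   # above the bound
print(bound, small, large)
```

Staying below this bound only keeps individual updates from overshooting. It does not by itself guarantee a good fit after one pass over the data.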

You can start with a larger value for the step size and decrease it over time. You can also iterate over your training set multiple times and/or compute the gradient on mini-batches, i.e. small subsets of the samples. In any case, check whether the gradient vanishes over time to see if you are converging. Also check whether your loss function goes down.
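These suggestions can be combined in one loop. The sketch below is one possible arrangement, not the asker's code: the function names, the 1 / (1 + decay * epoch) schedule, and all default values (a0, decay, epochs, batch_size, zero initial weights) are assumptions chosen for illustration. It records the full-data loss and gradient norm after every epoch so that convergence can be inspected.

```python
import numpy as np

def mse(phi, y, w):
    """Mean squared error of the linear model phi @ w."""
    return np.mean((phi @ w - y) ** 2)

def sgd(phi, y, w, a0=0.5, decay=0.01, epochs=500, batch_size=10, seed=0):
    rng = np.random.default_rng(seed)
    n = len(y)
    losses, grad_norms = [mse(phi, y, w)], []
    for epoch in range(epochs):
        a = a0 / (1.0 + decay * epoch)        # step size shrinks over time
        order = rng.permutation(n)            # reshuffle on every pass
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            err = phi[idx] @ w - y[idx]
            w = w - a * (phi[idx].T @ err) / len(idx)   # mini-batch gradient step
        losses.append(mse(phi, y, w))
        grad_norms.append(np.linalg.norm(phi.T @ (phi @ w - y) / n))
    return w, losses, grad_norms

# Synthetic data in the spirit of the question: noisy sin(2 pi x) on [0, 0.5]
rng = np.random.default_rng(0)
x = rng.uniform(0.0, 0.5, size=100)
y = np.sin(2 * np.pi * x) + rng.normal(0.0, 0.07, size=100)
phi = np.vander(x, 6, increasing=True)        # columns 1, x, ..., x^5

w, losses, grad_norms = sgd(phi, y, np.zeros(6))
print(losses[0], losses[-1], grad_norms[-1])
```

Plotting `losses` and `grad_norms` against the epoch index shows whether the run is still improving. With a raw polynomial basis, the higher-order directions tend to converge slowly, so the result may remain visibly different from the closed-form solution even after many passes.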

