簡體   English   中英

梯度下降法在線性回歸中效果不佳?

[英]Gradient descent method does not work well in linear regression?

在 R 中,我生成了一些人工數據以使用梯度下降法進行線性回歸 Y = c0 + c1 * x1 + c2 * x2 + 噪聲

我還使用解析方法來計算參數 theta = [c0, c1, c2]。 下面是帶有變量注釋的 R 代碼。

我使用梯度下降法來計算 theta。 公式取自下面的鏈接。

一個人的幻燈片

斯坦福的幻燈片 - Andrew Ng

但是,該方法無法收斂。 我的 R 代碼如下。 theta 與 R 代碼中的解析解 k 非常不同。

rm(list = ls())

n=500
x1=rnorm(n,mean=4,sd=1.6)
x2=rnorm(n,mean=4,sd=2.5)


X=cbind(x1,x2)
A=as.matrix(cbind(rep(1,n),x1,x2))
Y=-3.9+3.8*x1-2.4*x2+rnorm(n,mean=0,sd=1.5);


k=solve(t(A)%*%A,t(A)%*%Y) # k is the parameters determined by analytical method
MSE=sum((A%*%k-Y)^2)/(n);

iterations=3000 # total number of step
epsilon = 0.0001 # set precision
eta=0.0001 # step size

t1=integer(iterations)
e1=integer(iterations)

X=as.matrix(X)# convert data table X into a matrix
N=dim(X)[1] # total number of observations
X=as.matrix(cbind(rep(1,length(N)),X))# add a column of ones to represent intercept
np=dim(X)[2] # number of parameters to be determined
theta=matrix(rnorm(n=np,mean=0,sd=1),1,np) # Initialize theta:1 x np matrix
for(i in 1:iterations){
  error =theta%*%t(X)-t(Y) # error = (theta * x' -Y'). Error is a 1xN row vector;
  grad=(1/N)*error%*%X # Gradient grad is 1 x np vector
  theta=theta-eta*grad # updating theta
  L=sqrt(sum((eta*grad)^2)) # calculating the L2 norm
  e1[i]=sum((error)^2)/(2*N) # record the cost function in each step (value=2*MSE)
  t1[i]=L # record the L2 norm in each step
  if(L<=epsilon){ # checking whether convergence is obtained or not
    break
  }
}

plot(e1*2,type="l",ylab="MSE",lwd=2,col=rgb(0,0,1))
abline(h=MSE)
text(x=1000,y=MSE+1,labels = "Actual MSE",adj=1)
text(x=500,y=15,labels = "Gradient Descent",adj=0.4)
theta
k

我已經嘗試過代碼,問題是經過 3000 次迭代后,L2 范數 L 仍然大於精度 epsilon。 我可以通過首先運行set.seed(5556)來獲得可重現的示例數據。 經過 7419 次迭代后,L = 9.99938 e-5 < epsilon。 然而 theta 仍然與預期結果不同,當然對於 L2 范數和正常數據可以運行lm(Y ~ X1+X2)計算。

您的問題是您將eta設置為非常保守。 因此,收斂需要很長時間。 Eta 對收斂速度至關重要。 但是,如果選擇較大,則算法可能不會收斂。 您可能會在稍后的課程中了解自動調整 eta 的算法,如 Adagrad 或 Adam 如果您選擇 eta = 0.001 在此處輸入圖像描述

η= 0.01

在此處輸入圖像描述

η=0.1

在此處輸入圖像描述

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM