
Multivariate Linear Regression - Gradient Descent in R

I am learning machine learning, so I practice on data sets I find online. Right now I am trying to implement linear regression by gradient descent in R. When I run it, it does not converge and the cost grows without bound. I suspect the problem is in the part where I calculate the gradient, but I cannot find it. Let me start by presenting my data.

My data set contains 4 columns: ROLL ~ UNEM, HGRAD, INC. The goal is to find the relationship between ROLL and the other three.

Here is my code:

    datavar <- read.csv("dataset.csv")
    attach(datavar)

    X <- cbind(rep(1, 29), UNEM, HGRAD, INC)
    y <- ROLL

    # function where I calculate my prediction
    h <- function(X, theta){
      return(t(theta) %*% X)
    }

    # function where I calculate the cost with current values
    cost <- function(X, y, theta){
      result <- sum((X %*% theta - y)^2) / (2 * length(y))
      return(result)
    }

    # here I calculate the gradient,
    # mathematically speaking I calculate the derivative of the cost function at given points
    gradient <- function(X, y, theta){
      m <- nrow(X)
      sum <- c(0, 0, 0, 0)

      for (i in 1 : m) {
        sum <- sum + (h(X[i,], theta) - y[i]) * X[i,]
      }
      return(sum)
    }

    # The main algorithm
    gradientDescent <- function(X, y, maxit){
      alpha <- 0.005
      m <- nrow(X)
      theta <- c(0, 0, 0, 0)

      cost_history <- rep(0, maxit)

      for (i in 1 : maxit) {
        theta <- theta - alpha * (1/m) * gradient(X, y, theta)
        cost_history[i] <- cost(X, y, theta)
      }

      plot(1:maxit, cost_history, type = 'l')

      return(theta)
    }
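Written out, the cost and the update step that the code above implements are as follows. The notation is introduced here for illustration only: m is the number of rows, x^{(i)} is the i-th row of X, y^{(i)} is the i-th value of y, and \alpha is the learning rate.

```latex
J(\theta) = \frac{1}{2m} \sum_{i=1}^{m} \left( \theta^{\top} x^{(i)} - y^{(i)} \right)^2,
\qquad
\theta \leftarrow \theta - \frac{\alpha}{m} \sum_{i=1}^{m} \left( \theta^{\top} x^{(i)} - y^{(i)} \right) x^{(i)}
```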

I run the code like this:

    gradientDescent(X, y, 20)

This is the output I get:

    -7.001406e+118  -5.427330e+119  -1.192040e+123  -1.956518e+122

So, can you find where I went wrong? I have already tried different alpha values, but it made no difference. I would also appreciate any tips or good practices.

Thanks

Well, I think I finally found the answer. The problem was that I did not apply any feature scaling, because I thought it was an optional procedure that only makes the algorithm run more smoothly. With scaled features it works as expected. You can try running the code on a scaled data set using R's scale() function.
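A minimal sketch of what that could look like, under some stated assumptions: dataset.csv is not available here, so the data below is synthetic (the column names come from the question, but the values and coefficients are made up), the learning rate of 0.1 and the 2000 iterations are illustrative choices rather than the original settings, and the loop-based gradient is swapped for an equivalent vectorized form. The result is compared against lm() as a sanity check.

```r
# Synthetic stand-in for the question's data; values and coefficients are made up.
set.seed(42)
m <- 29
UNEM  <- runif(m, 5, 10)
HGRAD <- runif(m, 9000, 20000)
INC   <- runif(m, 1900, 3500)
ROLL  <- -9000 + 450 * UNEM + 0.4 * HGRAD + 4 * INC + rnorm(m, sd = 500)

# Standardize each feature to mean 0 and sd 1, then add the intercept column.
features <- scale(cbind(UNEM, HGRAD, INC))
X <- cbind(1, features)
y <- ROLL

# Same cost as in the question.
cost <- function(X, y, theta) {
  sum((X %*% theta - y)^2) / (2 * length(y))
}

# Vectorized gradient: the sum computed by the loop in the question, divided by m.
gradient <- function(X, y, theta) {
  as.vector(t(X) %*% (X %*% theta - y)) / nrow(X)
}

# alpha and maxit defaults are assumptions chosen for this sketch.
gradientDescent <- function(X, y, alpha = 0.1, maxit = 2000) {
  theta <- rep(0, ncol(X))
  cost_history <- numeric(maxit)
  for (i in seq_len(maxit)) {
    theta <- theta - alpha * gradient(X, y, theta)
    cost_history[i] <- cost(X, y, theta)
  }
  list(theta = theta, cost_history = cost_history)
}

fit <- gradientDescent(X, y)

# Sanity check: least-squares fit on the same scaled features.
lm_coef <- unname(coef(lm(y ~ features)))
print(cbind(gradient_descent = fit$theta, lm = lm_coef))
```

Note that the fitted theta applies to the scaled features, so new inputs have to be scaled with the same centers and standard deviations (stored by scale() in the "scaled:center" and "scaled:scale" attributes) before predicting.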
