
Neural network for linear regression

I found this great source that matched the exact model I needed: http://ufldl.stanford.edu/tutorial/supervised/LinearRegression/

The important bits go like this.

You have a plot of x -> y. Each x-value is the sum of "features", which I'll denote z.

So the regression line for the x -> y plot would be h(SUM(z_i)), where h(x) is the regression function.

In this NN, the idea is that each z-value is assigned a weight in a way that minimizes the sum of squared errors.
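Concretely, my understanding of the setup (with the line constants taken from the code further down) is:

    x = SUM_i( w_i * z_i )                       // weighted sum of the features
    prediction = h(x) = 5.1136 * x + 1.7238      // the regression line fit to the initial data
    error = 0.5 * (prediction - target)^2        // squared error to be minimized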

The gradient is used to update the weights to minimize the error. I believe I may be backpropagating incorrectly -- that is, where I update the weights.

So I wrote some code, but my weights aren't being correctly updated.

I may have simply misunderstood a spec from that Stanford post, so that's where I need your help. Can anyone verify I have correctly implemented this NN?

My h(x) function was a simple linear regression fit to the initial data. In other words, the idea is that the NN will adjust the weights so that the data points shift closer to this regression line.

for (epoch = 0; epoch < 10000; epoch++){

    //loop number of games
    for (game = 1; game < 39; game++){
      sum = 0;
      int temp1 = 0;
      int temp2 = 0;
      //loop number of inputs
      for (i = 0; i < 10; i++){
        //compute sum = x
        temp1 += inputs[game][i] * weights[i];
      }

      for (i = 10; i < 20; i++){
        temp2 += inputs[game][i] * weights[i];
      }

      sum = temp1 - temp2;

      //compute error
      error += .5 * (5.1136 * (sum) + 1.7238 - targets[game]) * (5.1136 * (sum) + 1.7238 - targets[game]);
      printf("error = %G\n", error);
      //backpropagate
      for (i = 0; i < 20; i++){
        weights[i] = sum * (5.1136 * (sum) + 1.7238 - targets[game]); //POSSIBLE ERROR HERE
      }

    }

    printf("Epoch = %d\n", epoch);
    printf("Error = %G\n", error);


  }

Please check out Andrew Ng's Machine Learning course on Coursera. He teaches Machine Learning at Stanford and can explain the concept of linear regression better than pretty much anyone else. You can learn the essentials of linear regression in the first lesson.

For linear regression, you are trying to minimize the cost function, which in this case is the sum of squared errors, (predicted value - actual value)^2, and the minimization is done by gradient descent. Solving a problem like this does not require a neural network, and using one would be rather inefficient.

For this problem, only two values are needed. If you think back to the equation for a line, y = mx + b, there are really only two aspects of a line that you need: The slope and the y-intercept. In linear regression you are looking for the slope and y-intercept that best fits the data.

In this problem, the two values can be represented by theta0 and theta1. theta0 is the y-intercept and theta1 is the slope.
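Written out in the usual notation (m is the number of training examples), the hypothesis and the cost function described above are:

    h(x) = theta0 + theta1 * x
    J(theta0, theta1) = (1 / (2 * m)) * SUM_i( (h(x_i) - y_i)^2 )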

This is the update function for Linear Regression:

    theta_j := theta_j - alpha * (1 / m) * SUM_i( (h(x_i) - y_i) * x_j_i )    (simultaneously for j = 0 and j = 1, with x_0_i = 1)

Here, theta is a 2 x 1 vector containing theta0 and theta1. What you are doing is taking theta and subtracting the average error (with each error scaled by the corresponding input for the slope term), multiplied by a learning rate alpha (usually small, like 0.1).

Let's say the true best-fit line is y = 2x + 3, but our current slope and y-intercept are both 0. The errors (predicted minus actual) will then be negative, and when a negative quantity is subtracted from theta, theta increases, moving your prediction closer to the correct value. And vice versa for positive errors. This is a basic example of gradient descent, where you are descending a slope to minimize the cost (or error) of the model.
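To make that concrete, here is a minimal C sketch of batch gradient descent for a two-parameter (theta0, theta1) linear regression. The data points, learning rate, and iteration count are made up for illustration (the points are generated near the y = 2x + 3 example above); this is not the code from the question, just the update rule described here.

    #include <stdio.h>

    int main(void) {
        /* Toy data roughly following y = 2x + 3 (made up for illustration). */
        double x[] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0};
        double y[] = {3.1, 4.9, 7.2, 8.8, 11.1, 13.0};
        int m = 6;                          /* number of training examples */

        double theta0 = 0.0, theta1 = 0.0;  /* y-intercept and slope, both start at 0 */
        double alpha = 0.01;                /* learning rate */

        for (int epoch = 0; epoch < 10000; epoch++) {
            double grad0 = 0.0, grad1 = 0.0, cost = 0.0;

            /* Accumulate the gradient over the whole data set (batch gradient descent). */
            for (int i = 0; i < m; i++) {
                double prediction = theta0 + theta1 * x[i];
                double error = prediction - y[i];
                grad0 += error;             /* d(cost)/d(theta0) contribution */
                grad1 += error * x[i];      /* d(cost)/d(theta1) contribution */
                cost  += 0.5 * error * error;
            }

            /* Simultaneous update: theta_j := theta_j - alpha * (1/m) * gradient_j */
            theta0 -= alpha * grad0 / m;
            theta1 -= alpha * grad1 / m;

            if (epoch % 1000 == 0)
                printf("epoch %d: cost = %G, theta0 = %G, theta1 = %G\n",
                       epoch, cost / m, theta0, theta1);
        }

        printf("final fit: y = %G * x + %G\n", theta1, theta0);
        return 0;
    }

With these values, theta1 and theta0 should head toward roughly 2 and 3. If the learning rate is too large, the cost diverges instead of shrinking, so it is worth printing the cost each epoch, as the question's code already does.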

This is the type of model you should be implementing instead of a neural network, which is more complex. Try to gain an understanding of linear and logistic regression with gradient descent before moving on to neural networks.

Implementing a linear regression algorithm in C can be rather challenging, especially without vectorization. If you are looking to learn how linear regression works and aren't specifically set on using C, I recommend implementing it in something like MATLAB or Octave (a free alternative) instead. After all, the examples in the tutorial you found use that format.
