Gradient descent on linear regression not converging

I have implemented a very simple linear regression with a gradient descent algorithm in JavaScript, but after consulting multiple sources and trying several things, I cannot get it to converge.

The data is perfectly linear: the inputs are just the numbers 0 to 30, with x * 3 as the correct outputs to learn.

This is the logic behind the gradient descent:

train(input, output) {
  const predictedOutput = this.predict(input);
  const delta = output - predictedOutput; // how far off the prediction is

  this.m += this.learningRate * delta * input; // adjust the slope
  this.b += this.learningRate * delta;         // adjust the intercept
}

predict(x) {
  return x * this.m + this.b; // the model: y = m * x + b
}
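
For context, this is roughly how the model is driven. The object wrapper, the learning rate value, and the epoch count here are illustrative assumptions, not my exact original code:

const model = {
  m: 0,
  b: 0,
  learningRate: 0.001,
  predict(x) { return x * this.m + this.b; },
  train(input, output) {
    const delta = output - this.predict(input);
    this.m += this.learningRate * delta * input;
    this.b += this.learningRate * delta;
  },
};

for (let epoch = 0; epoch < 1000; epoch++) {
  for (let x = 0; x <= 30; x++) model.train(x, x * 3); // learn y = 3x
}
console.log(model.m, model.b); // inspect the learned weights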

I took the formulas from several different places.

I have already tried:

  • normalizing input and output values to the [-1, 1] range
  • normalizing input and output values to the [0, 1] range
  • normalizing input and output values to have mean = 0 and stddev = 1 (a sketch of this follows the list)
  • reducing the learning rate (1e-7 is as low as I went)
  • having a linear data set with no bias at all (y = x * 3)
  • having a linear data set with a non-zero bias (y = x * 3 + 2)
  • initializing the weights with random non-zero values between -1 and 1
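
The mean/stddev normalization was along these lines (a minimal sketch; the helper name standardize is illustrative, not from my actual code):

// Standardize an array of numbers to mean = 0 and stddev = 1.
function standardize(values) {
  const mean = values.reduce((sum, v) => sum + v, 0) / values.length;
  const variance = values.reduce((sum, v) => sum + (v - mean) ** 2, 0) / values.length;
  const stddev = Math.sqrt(variance) || 1; // guard against a zero stddev
  return values.map(v => (v - mean) / stddev);
}

const inputs = standardize([...Array(31).keys()]);               // 0..30
const outputs = standardize([...Array(31).keys()].map(x => x * 3));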

Still, the weights (this.b and this.m) do not approach any of the data values, and they diverge into infinity.

I'm obviously doing something wrong, but I cannot figure out what it is.


Update: Here's a little more context that may help figure out exactly what my problem is:

I'm trying to model a simple approximation of a linear function, with online learning by a linear-regression pseudo-neuron. With that, my parameters are:

  • weights: [ this.m, this.b ]
  • inputs: [ x, 1 ]
  • activation function: the identity function z(x) = x

As such, my net will be expressed by y = this.m * x + this.b * 1, simulating the data-driven function that I want to approximate (y = 3 * x).

What I want is for my network to "learn" the parameters this.m = 3 and this.b = 0, but it seems I get stuck at a local minimum.

My error function is the mean squared error:

error(allInputs, allOutputs) {
  let error = 0;
  for (let i = 0; i < allInputs.length; i++) {
    const x = allInputs[i];
    const y = allOutputs[i];
    const predictedOutput = this.predict(x);
    const delta = y - predictedOutput; // residual for this sample

    error += delta * delta; // accumulate squared residuals
  }

  return error / allInputs.length; // average over all samples
}

My logic for updating my weights will be (according to the sources I've checked so far): wi -= alpha * dError/dwi

For the sake of simplicity, I'll call my weights this.m and this.b, so we can relate them back to my JavaScript code. I'll also call y^ the predicted value.

From here:

error = y - y^
      = y - (this.m * x + this.b)

dError/dm = -x
dError/db = 1

And so, applying that to the weight correction logic:

this.m += alpha * x
this.b -= alpha * 1

But this doesn't seem correct at all.

I finally found what's wrong, and I'm answering my own question in hopes it will help beginners in this area too.

First, as Sascha said, I had some theoretical misunderstandings. It may be correct that your adjustment includes the input value verbatim, but as he said, it should already be part of the gradient. This all depends on your choice of the error function.

Your error function is the measure of how far off you were from the real value, and that measurement needs to be consistent. I was using mean squared error as the measurement tool (as you can see in my error method), but I was using a plain absolute error (y^ - y) inside the training method to drive the update. Your gradient will depend on the choice of this error function, so choose only one and stick with it.
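
Concretely, for the squared error E = (y - y^)^2 with y^ = this.m * x + this.b, the derivatives are dE/dm = -2 * (y - y^) * x and dE/db = -2 * (y - y^). A consistent update step then looks roughly like this (a sketch; the factor of 2 is often just folded into the learning rate):

train(input, output) {
  const delta = output - this.predict(input);      // y - y^
  this.m += this.learningRate * 2 * delta * input; // -dE/dm = 2 * delta * x
  this.b += this.learningRate * 2 * delta;         // -dE/db = 2 * delta
}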

Second, simplify your assumptions in order to test what's wrong. In this case, I had a very good idea of what the function to approximate was (y = x * 3), so I manually set the weights (this.b and this.m) to the right values and still saw the error diverge. That means weight initialization was not the problem in this case.
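
That sanity check can be as small as this (a sketch, assuming the model object from earlier with an error method like the one above attached to it):

// With the known-correct weights, the error on clean data should be ~0.
// If it isn't, the bug is outside the optimizer.
model.m = 3;
model.b = 0;
const xs = [...Array(31).keys()]; // 0..30
const ys = xs.map(x => x * 3);
console.log(model.error(xs, ys)); // expect a value very close to 0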

After searching some more, I found my error was somewhere else: the function that was feeding data into the network was mistakenly passing a hardcoded 3 as the predicted output (it was using a wrong index in an array). So the oscillation I saw was the network trying to approximate y = 0 * x + 3 (this.b = 3 and this.m = 0); but because of the small learning rate and the error in the error function's derivative, this.b was never going to get near the right value, which made this.m take wild jumps to adjust to it.

Finally, keep track of the error measurement as your network trains, so you can get some insight into what's going on. This helps a lot in telling apart simple overfitting, an overly large learning rate, and plain simple mistakes.
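
Logging the error every few epochs is enough to see whether it's shrinking, oscillating, or blowing up (a sketch reusing the hypothetical names from above):

// Print the mean squared error every 100 epochs to watch convergence.
for (let epoch = 0; epoch < 1000; epoch++) {
  for (let i = 0; i < xs.length; i++) model.train(xs[i], ys[i]);
  if (epoch % 100 === 0) {
    console.log(`epoch ${epoch}: error = ${model.error(xs, ys)}`);
  }
}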
