
Linear regression using TensorFlow

import tensorflow as tf

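# Pass the dtype by keyword; the second positional argument of tf.Variable is `trainable`
M = tf.Variable([0.01], dtype=tf.float32)
b = tf.Variable([1.0], dtype=tf.float32)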

#inputs and outputs

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32) # actual value of y which we already know

Yp = M * x + b # y predicted value

#loss

squareR = tf.square(Yp - y)
loss =  tf.reduce_sum(squareR)

#optimize

optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

for i in range(1000):
    sess.run(train,{x:[1,2,3,4,5],y:[1.9,2.4,3.7,4.9,5.1]})
print(sess.run([M,b]))

Output:

[array([ 0.88999945], dtype=float32), array([ 0.93000191], dtype=float32)]

Problem: when I change the values of x and y to

x:[100,200,300,400,500],y:[19,24,37,49,51]

the output is:

[array([ nan], dtype=float32), array([ nan], dtype=float32)]

Please help me get the slope and y-intercept of the linear model.

Adding some print statements to your training loop shows what is going on during training:

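import numpy as np

for i in range(1000):
    # Fetch M and b alongside the training op to watch them change
    _, mm, bb = sess.run([train, M, b], {x: [100,200,300,400,500], y: [19,24,37,49,51]})
    print(mm, bb)
    if np.isnan(mm).any():  # stop once the parameters have become NaN
        break
print(sess.run([M, b]))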

The output:

[ 1118.01000977] [ 4.19999981]
[-12295860.] [-33532.921875]
[  1.35243170e+11] [  3.68845632e+08]
[ -1.48755065e+15] [ -4.05696309e+12]
[  1.63616896e+19] [  4.46228634e+16]
[ -1.79963571e+23] [ -4.90810521e+20]
[  1.97943407e+27] [  5.39846559e+24]
[ -2.17719537e+31] [ -5.93781625e+28]
[  2.39471499e+35] [  6.53105210e+32]
[-inf] [-inf]
[ nan] [ nan]

That output means your training is diverging: each update overshoots the minimum, so M and b flip sign and grow in magnitude on every step until they overflow to inf and then become nan. In this case, lowering the learning rate is one possible way to fix the problem.
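To see the divergence without TensorFlow, here is a minimal sketch of the same update rule in plain Python. It assumes the loss is the sum of squared errors, as in the question, and computes the gradients by hand; the function name `gradient_descent` is only illustrative. It uses 64-bit floats, so the overflow happens after more steps than in the float32 trace above, but the first update gives the same M of about 1118.01 and b of about 4.2.

```python
import math

def gradient_descent(xs, ys, lr, steps=1000, M=0.01, b=1.0):
    """Plain gradient descent on loss = sum((M*x + b - y)**2)."""
    for _ in range(steps):
        errors = [M * x + b - y for x, y in zip(xs, ys)]
        grad_M = 2 * sum(e * x for e, x in zip(errors, xs))  # d(loss)/dM
        grad_b = 2 * sum(errors)                             # d(loss)/db
        M -= lr * grad_M
        b -= lr * grad_b
        if not (math.isfinite(M) and math.isfinite(b)):
            break  # parameters overflowed: training has diverged
    return M, b

xs = [100, 200, 300, 400, 500]
ys = [19, 24, 37, 49, 51]

print(gradient_descent(xs, ys, lr=0.01))      # diverges to inf/nan
print(gradient_descent(xs, ys, lr=0.000001))  # stays finite
```

With inputs in the hundreds, the gradient with respect to M is proportional to the sum of x squared (550000 here), so a step of 0.01 is far too large, while 0.000001 keeps each update small enough to stay stable.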

Lowering the learning rate to 0.000001 works. These are the learned M and b after 1000 iterations:

[array([ 0.11159456], dtype=float32), array([ 1.01534212], dtype=float32)]
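As a cross-check on those values, the slope and intercept can also be computed directly. The sketch below assumes the goal is the ordinary least-squares line, and the helper name `least_squares` is hypothetical. It is written in plain Python and is not part of the original answer.

```python
def least_squares(xs, ys):
    """Closed-form slope and intercept of the least-squares line."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return slope, intercept

slope, intercept = least_squares([100, 200, 300, 400, 500], [19, 24, 37, 49, 51])
print(slope, intercept)  # approximately 0.089 and 9.3
```

Compared with this, the b learned above has barely moved from its initial value of 1.0. With such a small learning rate, 1000 iterations do not appear to be enough for the intercept to converge, so more iterations or rescaled inputs may be needed to get close to the least-squares line.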
