繁体   English   中英

为什么我这么简单的线性回归不起作用

[英]Why isn't my so simple linear regression working

我是 tensorflow-2 的新手,我开始学习曲线,遵循简单的线性回归 model:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt


# Make data
num_samples, w, b = 20, 0.5, 2
xs = np.asarray(range(num_samples))
ys = np.asarray([x*w + b + np.random.normal() for x in range(num_samples)])
xts = tf.convert_to_tensor(xs, dtype=tf.float32)
yts = tf.convert_to_tensor(xs, dtype=tf.float32)
plt.plot(xs, ys, 'ro')

class Linear(tf.keras.Model):
    def __init__(self, name='linear', **kwargs):
        super().__init__(name='linear', **kwargs)
        self.w = tf.Variable(0, True, name="w", dtype=tf.float32)
        self.b = tf.Variable(1, True, name="b", dtype=tf.float32)   

    def call(self, inputs):
        return self.w*inputs + self.b

class Custom(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        if epoch % 20 == 0:
            preds = self.model.predict(xts)
            plt.plot(xs, preds, label='{} {:7.2f}'.format(epoch, logs['loss']))
            print('The average loss for epoch {} is .'.format(epoch, logs['loss']))

x = tf.keras.Input(dtype=tf.float32, shape=[])
#model = tf.keras.Sequential([tf.keras.layers.Dense(units=1, input_shape=[1])])
model = Linear()
optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='MSE')
model.fit(x=xts, y=yts, verbose=1, batch_size=4, epochs=250, callbacks=[Custom()])

plt.legend()
plt.show()

由于我不明白的原因,我的 model 似乎不符合曲线。 我也尝试了 keras.layers.Dense(1) 并且得到了完全相同的结果。 此外,结果似乎与适当的损失 function 不对应,因为在纪元 120 左右,model 的损失应该小于纪元 250 时的损失。

绝望的彩虹

你能帮我理解我做错了什么吗? 非常感谢!

您的代码中有一个小错误,因为xtsyts彼此相同,即您写了

xts = tf.convert_to_tensor(xs, dtype=tf.float32)
yts = tf.convert_to_tensor(xs, dtype=tf.float32)

代替

xts = tf.convert_to_tensor(xs, dtype=tf.float32)
yts = tf.convert_to_tensor(ys, dtype=tf.float32)

这就是为什么损失没有意义的原因。 修复此问题后,结果如预期,请参阅下面的 plot。

在此处输入图像描述

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM