簡體   English   中英

為什么我這么簡單的線性回歸不起作用

[英]Why isn't my so simple linear regression working

我是 tensorflow-2 的新手,我開始學習曲線,遵循簡單的線性回歸 model:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt


# Make data
num_samples, w, b = 20, 0.5, 2
xs = np.asarray(range(num_samples))
ys = np.asarray([x*w + b + np.random.normal() for x in range(num_samples)])
xts = tf.convert_to_tensor(xs, dtype=tf.float32)
yts = tf.convert_to_tensor(xs, dtype=tf.float32)
plt.plot(xs, ys, 'ro')

class Linear(tf.keras.Model):
    def __init__(self, name='linear', **kwargs):
        super().__init__(name='linear', **kwargs)
        self.w = tf.Variable(0, True, name="w", dtype=tf.float32)
        self.b = tf.Variable(1, True, name="b", dtype=tf.float32)   

    def call(self, inputs):
        return self.w*inputs + self.b

class Custom(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        if epoch % 20 == 0:
            preds = self.model.predict(xts)
            plt.plot(xs, preds, label='{} {:7.2f}'.format(epoch, logs['loss']))
            print('The average loss for epoch {} is .'.format(epoch, logs['loss']))

x = tf.keras.Input(dtype=tf.float32, shape=[])
#model = tf.keras.Sequential([tf.keras.layers.Dense(units=1, input_shape=[1])])
model = Linear()
optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='MSE')
model.fit(x=xts, y=yts, verbose=1, batch_size=4, epochs=250, callbacks=[Custom()])

plt.legend()
plt.show()

由於我不明白的原因,我的 model 似乎不符合曲線。 我也嘗試了 keras.layers.Dense(1) 並且得到了完全相同的結果。 此外,結果似乎與適當的損失 function 不對應,因為在紀元 120 左右,model 的損失應該小於紀元 250 時的損失。

絕望的彩虹

你能幫我理解我做錯了什么嗎? 非常感謝!

您的代碼中有一個小錯誤,因為xtsyts彼此相同,即您寫了

xts = tf.convert_to_tensor(xs, dtype=tf.float32)
yts = tf.convert_to_tensor(xs, dtype=tf.float32)

代替

xts = tf.convert_to_tensor(xs, dtype=tf.float32)
yts = tf.convert_to_tensor(ys, dtype=tf.float32)

這就是為什么損失沒有意義的原因。 修復此問題后,結果如預期,請參閱下面的 plot。

在此處輸入圖像描述

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM