I am new to tensorflow-2 and I was starting my learning curve, with the follow simple Linear-Regression model:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Make data
num_samples, w, b = 20, 0.5, 2
xs = np.asarray(range(num_samples))
ys = np.asarray([x*w + b + np.random.normal() for x in range(num_samples)])
xts = tf.convert_to_tensor(xs, dtype=tf.float32)
yts = tf.convert_to_tensor(xs, dtype=tf.float32)
plt.plot(xs, ys, 'ro')
class Linear(tf.keras.Model):
def __init__(self, name='linear', **kwargs):
super().__init__(name='linear', **kwargs)
self.w = tf.Variable(0, True, name="w", dtype=tf.float32)
self.b = tf.Variable(1, True, name="b", dtype=tf.float32)
def call(self, inputs):
return self.w*inputs + self.b
class Custom(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
if epoch % 20 == 0:
preds = self.model.predict(xts)
plt.plot(xs, preds, label='{} {:7.2f}'.format(epoch, logs['loss']))
print('The average loss for epoch {} is .'.format(epoch, logs['loss']))
x = tf.keras.Input(dtype=tf.float32, shape=[])
#model = tf.keras.Sequential([tf.keras.layers.Dense(units=1, input_shape=[1])])
model = Linear()
optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='MSE')
model.fit(x=xts, y=yts, verbose=1, batch_size=4, epochs=250, callbacks=[Custom()])
plt.legend()
plt.show()
For a reason I don't understand it seems like my model is not fitting the curve. I also tried with keras.layers.Dense(1) and I had the same exact result. Also it seems like the results don't correspond to a proper loss function, as around epoch 120 the model should have less loss than on 250.
Can you maybe help me understand what I am doing wrong? Thanks a lot!
There is a small bug in your code as xts
and yts
are identical to each other, ie you wrote
xts = tf.convert_to_tensor(xs, dtype=tf.float32)
yts = tf.convert_to_tensor(xs, dtype=tf.float32)
instead of
xts = tf.convert_to_tensor(xs, dtype=tf.float32)
yts = tf.convert_to_tensor(ys, dtype=tf.float32)
which is why the loss doesn't make sense. Once this has been fixed the results are as expected, see the plot below.
The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.