Training a basic TensorFlow Model using the GradientTape
Simply for educational purposes, I was trying to build upon the Basic training loops tutorial from the TensorFlow homepage to create a simple neural network that classifies points in the plane.
So, I have some points in [0,1]x[0,1] stored in a tensor x of shape (250, 2, 1), and the corresponding labels (1. or 0.) stored in a tensor y of shape (250, 1, 1).
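For reference, data with these shapes can be generated roughly like this (the labeling rule below, points above vs. below the diagonal, is just a hypothetical placeholder, not my actual data):

import tensorflow as tf

# Hypothetical example data: 250 random points in [0,1]x[0,1],
# labeled 1. if the point lies above the diagonal, else 0.
points = tf.random.uniform([250, 2, 1])
labels = tf.cast(points[:, 1] > points[:, 0], tf.float32)  # shape (250, 1)
labels = tf.reshape(labels, [250, 1, 1])

x = points  # shape (250, 2, 1)
y = labels  # shape (250, 1, 1)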
Then I do
import tensorflow as tf

# Trainable parameters: 2 inputs -> 4 hidden units -> 1 output
w0 = tf.Variable(tf.random.normal([4, 2]), name='w0')
w1 = tf.Variable(tf.random.normal([1, 4]), name='w1')
b1 = tf.Variable(tf.zeros([4, 1]), name='b1')
b2 = tf.Variable(tf.zeros([1, 1]), name='b2')

loss = tf.keras.losses.CategoricalCrossentropy()

def forward(x):
    x0 = x
    z1 = tf.matmul(w0, x0) + b1   # hidden pre-activation, shape (250, 4, 1)
    x1 = tf.nn.relu(z1)
    z2 = tf.matmul(w1, x1) + b2   # output pre-activation, shape (250, 1, 1)
    x2 = tf.nn.sigmoid(z2)
    return x2

with tf.GradientTape() as t:
    current_loss = loss(y, forward(x))
gradients = t.gradient(current_loss, [b1, b2, w0, w1])
What I get is a list of tensors of the expected shapes, but they contain only zeros. Does anyone have some advice?
The issue happens because the labels/predictions do not have the expected shapes. In particular, the loss function tf.keras.losses.CategoricalCrossentropy expects labels to be provided in a one-hot representation, but your labels and predictions have shape (250, 1, 1), and the behaviour of the loss function is unclear in this situation. Using tf.keras.losses.BinaryCrossentropy instead should solve the problem.
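A minimal sketch of that change, keeping the forward pass and the tape from your code unchanged:

# swap the loss; everything else stays as in the question
loss = tf.keras.losses.BinaryCrossentropy()

with tf.GradientTape() as t:
    current_loss = loss(y, forward(x))
gradients = t.gradient(current_loss, [b1, b2, w0, w1])

With the binary cross-entropy, the (250, 1, 1) labels and sigmoid outputs are compared element-wise, so the gradients should no longer be all zeros.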