Why isn't my neural network able to classify correctly?
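I built a simple neural network to classify food into two classes, egg or meat, but every time I train the model it gives me a constant result no matter which image I feed it. For example, the first time I trained it, it recognized every image as meat; the second time, it recognized every image as egg. I don't know whether this is a bug in my code.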
Here is where I read the data:
import tensorflow as tf

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    directory,
    labels="inferred",
    label_mode="int",
    class_names=None,
    color_mode="rgb",
    batch_size=32,
    image_size=(256, 256),
    seed=None,
    validation_split=None,
    subset=None,
    interpolation="bilinear",
    follow_links=False,
    crop_to_aspect_ratio=False,
)
This is where I make the prediction, applying the softmax activation function after flattening the data:
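# flatten is not defined in the post; a Keras Flatten layer is assumed here
flatten = tf.keras.layers.Flatten()

def forward(x):
    # linear layer: logits = xW + b
    return tf.matmul(x, W) + b

def model(x):
    x = flatten(x)
    return activate(x)

def activate(x):
    return tf.nn.softmax(forward(x))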
Computing the error with cross_entropy:
def cross_entropy(y_label, y_pred):
    return -tf.reduce_sum(y_label * tf.math.log(y_pred + 1.e-10))
Updating the values with gradient descent:
optimizer = tf.keras.optimizers.SGD(learning_rate=0.25)

def train_step(x, y):
    with tf.GradientTape() as tape:
        # compute loss function
        current_loss = cross_entropy(y, model(x))
    # compute gradient of loss
    # (This is automatic! Even with specialized functions!)
    grads = tape.gradient(current_loss, [W, b])
    # Apply SGD step to our Variables W and b
    optimizer.apply_gradients(zip(grads, [W, b]))
    return current_loss.numpy()
Finally, training the model:
# Weight tensor: 256 * 256 * 3 = 196608 inputs, 2 classes
W = tf.Variable(tf.zeros([196608, 2], tf.float32))
# Bias tensor
b = tf.Variable(tf.zeros([2], tf.float32))

loss_values = []
accuracies = []
epochs = 100
for i in range(epochs):
    j = 0
    # each batch has up to 32 examples (batch_size=32 above)
    for x_train_batch, y_train_batch in train_ds:
        j += 1
        current_loss = train_step(x_train_batch / 255.0, tf.one_hot(y_train_batch, 2))
        if j % 500 == 0:  # reporting intermittent batch statistics
            print("epoch ", str(i), "batch", str(j), "loss:", str(current_loss))
Update:
I have discovered that the problem is in the gradients: they are always zero except for the first time.
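One way to see how a gradient can be exactly zero is to work through the chain rule by hand. The following is a minimal sketch in plain Python (not TensorFlow), assuming two classes and hypothetical logit values chosen for illustration; the helper names `softmax2`, `naive_grad`, and `fused_grad` are not from the post. It compares the gradient of `-sum(y * log(p + eps))` taken through the softmax with the gradient of the same loss computed directly from the logits.

```python
import math

def softmax2(z0, z1):
    """Numerically stable softmax over two logits."""
    m = max(z0, z1)
    e0, e1 = math.exp(z0 - m), math.exp(z1 - m)
    s = e0 + e1
    return e0 / s, e1 / s

def naive_grad(z, y, eps=1e-10):
    """Gradient of -sum(y * log(p + eps)) w.r.t. the logits, via the chain rule."""
    p = softmax2(*z)
    # dL/dp_k = -y_k / (p_k + eps)
    dl_dp = [-y[k] / (p[k] + eps) for k in range(2)]
    grad = []
    for j in range(2):
        # softmax Jacobian: dp_k/dz_j = p_k * (delta_kj - p_j)
        g = sum(dl_dp[k] * p[k] * ((1.0 if k == j else 0.0) - p[j])
                for k in range(2))
        grad.append(g)
    return grad

def fused_grad(z, y):
    """Gradient of softmax cross-entropy taken directly from the logits: p - y."""
    p = softmax2(*z)
    return [p[k] - y[k] for k in range(2)]

y = (0.0, 1.0)               # one-hot label: the true class is index 1
small = (0.5, -0.5)          # moderate logits (hypothetical values)
large = (1000.0, -1000.0)    # saturated logits, confidently wrong (hypothetical)

print("naive, moderate logits: ", naive_grad(small, y))
print("naive, saturated logits:", naive_grad(large, y))
print("fused, saturated logits:", fused_grad(large, y))
```

With moderate logits the two gradients agree. Once the softmax output rounds to exactly 0 and 1, every entry of the softmax Jacobian is zero, so the chain-rule gradient vanishes even though the prediction is wrong, while the `p - y` form stays nonzero. In TensorFlow, a loss computed from logits such as `tf.nn.softmax_cross_entropy_with_logits` may be worth trying for this reason; whether it resolves this particular case is not confirmed by the post.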
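Just to be sure: are you actually going through the training data and calling the training step? In the code you provided there is no training loop iterating over your train_ds, so it seems the behavior depends more on the initialization, since you aren't actually training the model.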
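After several days of debugging, I found that the grads were 0, and then I understood why: it is the softmax function. Because the values computed from x_train are very large, softmax outputs exactly 0 and 1, which makes the change tend to 0. To work around it, I simply divided the forward(x) arguments by a large number.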
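A rough back-of-the-envelope estimate suggests why the values get so large after a single update. The sketch below is plain Python and uses the learning rate and input size from the question; the mean pixel value, the class imbalance within the first batch, and the scale factor are assumptions made only for illustration, and `prob_class1` is a hypothetical helper.

```python
import math

n_features = 256 * 256 * 3   # 196608, the first dimension of W in the question
learning_rate = 0.25         # from the question's SGD optimizer
mean_pixel = 0.5             # assumption: typical pixel value after dividing by 255
imbalance = 2                # assumption: 17 class-0 vs 15 class-1 images in a batch

# With W = 0 every prediction is (0.5, 0.5), so the gradient of the summed loss
# for each weight in column 0 is roughly mean_pixel * 0.5 * (n1 - n0).
grad_w0 = mean_pixel * 0.5 * (-imbalance)
w0_after_step = -learning_rate * grad_w0

# Column 1 receives the opposite update, so the two logits are +/- logit0.
logit0 = n_features * mean_pixel * w0_after_step
logit_gap = 2 * logit0

def prob_class1(gap):
    """Softmax probability of class 1 when logit0 - logit1 == gap (gap >= 0)."""
    e = math.exp(-gap)
    return e / (1.0 + e)

scale = 10000.0              # assumption: the "large number" used as a divisor

p_raw = prob_class1(logit_gap)
p_scaled = prob_class1(logit_gap / scale)

print("logit gap after one step:", logit_gap)
print("p(class 1), raw logits:   ", p_raw, " p*(1-p) =", p_raw * (1 - p_raw))
print("p(class 1), scaled logits:", p_scaled, " p*(1-p) =", p_scaled * (1 - p_scaled))
```

Under these assumptions the logits differ by tens of thousands after one step, so the softmax output is exactly 0 or 1 and the factor `p * (1 - p)` that the gradient depends on is zero. Dividing the logits by a large number keeps the softmax away from saturation. A smaller learning rate or averaging the loss over the batch instead of summing it would shrink the first update in a similar way, though the post does not report trying either.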