簡體   English   中英

訓練期間 TensorFlow 2 中多輸入圖像分類任務中的 AssertionError

[英]AssertionError in multi-input image classification task in TensorFlow 2 during training

我有以下使用 tf.keras Functional API 構建的模型結構,其中整個想法是輸入同一項目的不同視圖的六個圖像並獲得單一分類:

在此處輸入圖片說明

我有兩個輸出類。 我在分類中使用 softmax 而不是 sigmoid,因為我想稍后分別檢查每個類,如果它們不同則更容易。

因此,我相信我的輸入形狀應該是[batch_size, 6, img_width, img_height, num_channels] ,盡管我不是 100% 確定。 因為我使用的是 one-hot 標簽,所以我的標簽形狀應該是[batch_size, num_classes] ,對嗎?

但是,對於輸入 X 和 Y 形狀:

Tensor("X:0", shape=(2, 6, 224, 224, 3), dtype=float32)
Tensor("Y:0", shape=(2, 2), dtype=float32)

TensorFlow 抱怨AssertionError: Could not compute output Tensor("dense_1/Softmax:0", shape=(None, 2), dtype=float32)

追溯這個錯誤,我可以看到它是從我的訓練函數開始的:

def train_function(X, Y, model, loss, optimizer, metric):
    with tf.GradientTape() as tape:
        predictions = model(X, training=True) #<- where the traceback occurs
        loss_value = loss(Y, predictions)

    gradients = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_acc = metric.update_state(Y, predictions)
    return loss_value

這個函數本身在訓練循環中被調用:

    for batch, (X, Y) in train_ds.enumerate():
        train_loss = train_function(X, Y, model, loss, optimizer, metric)  # do the actual training
        train_acc = metric.result().numpy()  # get the training accuracy
        batch_train_loss.append(train_loss)  # save the training loss above
        batch_train_acc.append(train_acc)  # save the training accuracy above
        metric.reset_states()  # reset the metric after every batch

其中train_ds與內置tf.data.Dataset進口TFRecords

你有什么建議可能是錯的嗎? 我唯一能想到的是輸入形狀不正確,但我不知道它應該是什么形狀。

這里的正確答案是給每個輸入一個名稱:

dorsal_image = tf.keras.layers.Input(shape=(None, None, 3), name="dorsal_input")
medial_image = tf.keras.layers.Input(shape=(None, None, 3), name="medial_input")
plantar_image = tf.keras.layers.Input(shape=(None, None, 3), name="plantar_input")
lateral_image = tf.keras.layers.Input(shape=(None, None, 3), name="lateral_input")
proximal_image = tf.keras.layers.Input(shape=(None, None, 3), name="proximal_input")
distal_image = tf.keras.layers.Input(shape=(None, None, 3), name="distal_input")

並在 TFRecord 處理中確保輸出是字典而不是列表:

return {'dorsal_input': image_lst[0], 'medial_input': image_lst[1],
        'plantar_input': image_lst[2], 'lateral_input': image_lst[3], 
        'proximal_input': image_lst[4], 'distal_input': image_lst[5]}, label

這實際上是通過查看 samcaetano 在這個 github 問題上的回答來解決的

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM