[英]How to build multi-input image process by Tensorflow ImageDataGenerator
[英]AssertionError in multi-input image classification task in TensorFlow 2 during training
我有以下使用 tf.keras Functional API 構建的模型結構,其中整個想法是輸入同一項目的不同視圖的六個圖像並獲得單一分類:
我有兩個輸出類。 我在分類中使用 softmax 而不是 sigmoid,因為我想稍后分別檢查每個類,如果它們不同則更容易。
因此,我相信我的輸入形狀應該是[batch_size, 6, img_width, img_height, num_channels]
,盡管我不是 100% 確定。 因為我使用的是 one-hot 標簽,所以我的標簽形狀應該是[batch_size, num_classes]
,對嗎?
但是,對於輸入 X 和 Y 形狀:
Tensor("X:0", shape=(2, 6, 224, 224, 3), dtype=float32)
Tensor("Y:0", shape=(2, 2), dtype=float32)
TensorFlow 抱怨AssertionError: Could not compute output Tensor("dense_1/Softmax:0", shape=(None, 2), dtype=float32)
追溯這個錯誤,我可以看到它是從我的訓練函數開始的:
def train_function(X, Y, model, loss, optimizer, metric):
with tf.GradientTape() as tape:
predictions = model(X, training=True) #<- where the traceback occurs
loss_value = loss(Y, predictions)
gradients = tape.gradient(loss_value, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
train_acc = metric.update_state(Y, predictions)
return loss_value
這個函數本身在訓練循環中被調用:
for batch, (X, Y) in train_ds.enumerate():
train_loss = train_function(X, Y, model, loss, optimizer, metric) # do the actual training
train_acc = metric.result().numpy() # get the training accuracy
batch_train_loss.append(train_loss) # save the training loss above
batch_train_acc.append(train_acc) # save the training accuracy above
metric.reset_states() # reset the metric after every batch
其中train_ds
與內置tf.data.Dataset
進口TFRecords
。
你有什么建議可能是錯的嗎? 我唯一能想到的是輸入形狀不正確,但我不知道它應該是什么形狀。
這里的正確答案是給每個輸入一個名稱:
dorsal_image = tf.keras.layers.Input(shape=(None, None, 3), name="dorsal_input")
medial_image = tf.keras.layers.Input(shape=(None, None, 3), name="medial_input")
plantar_image = tf.keras.layers.Input(shape=(None, None, 3), name="plantar_input")
lateral_image = tf.keras.layers.Input(shape=(None, None, 3), name="lateral_input")
proximal_image = tf.keras.layers.Input(shape=(None, None, 3), name="proximal_input")
distal_image = tf.keras.layers.Input(shape=(None, None, 3), name="distal_input")
並在 TFRecord 處理中確保輸出是字典而不是列表:
return {'dorsal_input': image_lst[0], 'medial_input': image_lst[1],
'plantar_input': image_lst[2], 'lateral_input': image_lst[3],
'proximal_input': image_lst[4], 'distal_input': image_lst[5]}, label
這實際上是通過查看 samcaetano 在這個 github 問題上的回答來解決的。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.