簡體   English   中英

將張量元組傳遞給多個損失函數時,Tensorflow/keras 僅使用.fit 傳遞元組中的一個張量

[英]Tensorflow/keras when passing tuple of tensors to multiple loss functions, only one tensor in tuple is being passed using .fit

訓練數據的形狀為:

pred: [
         ([batch], 52, 52, 3, 6),
         ([batch], 26, 26, 3, 6),
         ([batch], 13, 13, 3, 6),
    ]
true: [
        [
         ([batch], 52, 52, 3, 6),
         ([batch], 128, 4),
        ],
        [
         ([batch], 26, 26, 3, 6),
         ([batch], 128, 4),
        ],
        [
         ([batch], 13, 13, 3, 6),
         ([batch], 128, 4),
        ]
    ]

我有 3 個輸出,並通過將損失列表傳遞給 model.compile 為每個輸出指定了損失 function。 問題是在訓練時,損失 function 只接收到它應該接收的兩個值之一。 當我手動調用它時,它工作正常。

例如:與

model.compile(
    optimizer=tf.optimizers.Adam(), 
    loss = losses
)

這會產生正確的行為

losses[0](y[0], model(x)[0])

而output這個:

def loss(target, output):
    print(output.shape)
    print(target[0].shape)
    print(target[1].shape)

是:

(1, 52, 52, 18)
(1, 52, 52, 3, 6)
(1, 128, 4)

但是當使用以下方法調用損失時:

model.fit(x,y,batch_size=1)

我得到這些尺寸:

(1, 52, 52, 18)
(52, 52, 3, 6)

然后這個錯誤:

ValueError: slice index 1 of dimension 0 out of bounds

建議在調用 through.fit 時,一次只傳遞元組中的一個張量

我真的很想避免自定義訓練循環以利用回調等功能。所以有什么方法可以一次將兩個張量傳遞給元組中的損失 function ?

這重現了這個問題:

import tensorflow as tf
import numpy as np

x = tf.data.Dataset.from_tensor_slices(np.full((100,1), 1.0))
y0 = tf.data.Dataset.from_tensor_slices(np.full((100,1), 2.0))
y1 = tf.data.Dataset.from_tensor_slices(np.full((100,1), 3.0))
y = tf.data.Dataset.zip((y0, y1))

train_data = tf.data.Dataset.zip((x, y)) 

inputs = tf.keras.Input(shape=(1,))
outputs = tf.keras.layers.Dense(1)(inputs)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

def myloss(y_true, y_pred):
    y0 = y_true[0]
    y1 = y_true[1]
    return tf.reduce_mean((y0 - y_pred)**2 + 0.5*(y1 - y_pred)**2)

model.compile(tf.keras.optimizers.Adam(), loss=myloss)
model.fit(train_data, epochs=1)

這里的目標 y 是一個元組,但是當將它傳遞給損失 function model.fit 時,只包含元組的第一個元素。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM