繁体   English   中英

将张量元组传递给多个损失函数时,Tensorflow/keras 仅使用.fit 传递元组中的一个张量

[英]Tensorflow/keras when passing tuple of tensors to multiple loss functions, only one tensor in tuple is being passed using .fit

训练数据的形状为:

pred: [
         ([batch], 52, 52, 3, 6),
         ([batch], 26, 26, 3, 6),
         ([batch], 13, 13, 3, 6),
    ]
true: [
        [
         ([batch], 52, 52, 3, 6),
         ([batch], 128, 4),
        ],
        [
         ([batch], 26, 26, 3, 6),
         ([batch], 128, 4),
        ],
        [
         ([batch], 13, 13, 3, 6),
         ([batch], 128, 4),
        ]
    ]

我有 3 个输出,并通过将损失列表传递给 model.compile 为每个输出指定了损失 function。 问题是在训练时,损失 function 只接收到它应该接收的两个值之一。 当我手动调用它时,它工作正常。

例如:与

model.compile(
    optimizer=tf.optimizers.Adam(), 
    loss = losses
)

这会产生正确的行为

losses[0](y[0], model(x)[0])

而output这个:

def loss(target, output):
    print(output.shape)
    print(target[0].shape)
    print(target[1].shape)

是:

(1, 52, 52, 18)
(1, 52, 52, 3, 6)
(1, 128, 4)

但是当使用以下方法调用损失时:

model.fit(x,y,batch_size=1)

我得到这些尺寸:

(1, 52, 52, 18)
(52, 52, 3, 6)

然后这个错误:

ValueError: slice index 1 of dimension 0 out of bounds

建议在调用 through.fit 时,一次只传递元组中的一个张量

我真的很想避免自定义训练循环以利用回调等功能。所以有什么方法可以一次将两个张量传递给元组中的损失 function ?

这重现了这个问题:

import tensorflow as tf
import numpy as np

x = tf.data.Dataset.from_tensor_slices(np.full((100,1), 1.0))
y0 = tf.data.Dataset.from_tensor_slices(np.full((100,1), 2.0))
y1 = tf.data.Dataset.from_tensor_slices(np.full((100,1), 3.0))
y = tf.data.Dataset.zip((y0, y1))

train_data = tf.data.Dataset.zip((x, y)) 

inputs = tf.keras.Input(shape=(1,))
outputs = tf.keras.layers.Dense(1)(inputs)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

def myloss(y_true, y_pred):
    y0 = y_true[0]
    y1 = y_true[1]
    return tf.reduce_mean((y0 - y_pred)**2 + 0.5*(y1 - y_pred)**2)

model.compile(tf.keras.optimizers.Adam(), loss=myloss)
model.fit(train_data, epochs=1)

这里的目标 y 是一个元组,但是当将它传递给损失 function model.fit 时,只包含元组的第一个元素。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM