![](/img/trans.png)
[英]Problem with multi-output and multi-loss functions in tensorflow in python 3
[英]Separate trainX and trainY using Tensorflow 2 input pipeline for multi-loss training
我正在构建一个网络,将一张图片作为输入并输出两个预测,并从两个预测中计算出损失 function。 所以:
输入(trainX): RGBD输出1(trainGT): GT输出2(trainError):错误
当我使用 TF 的标准输入管道时,它会压缩输入和目标。 但是,当我定义损失 function 时,我需要将两个目标分开。 这是我的输入管道代码:
@tf.function
def load_image(point_file):
name = tf.strings.split(point_file,'/')[-1]
GT = tf.io.read_file(GT_path + name)
RGBD = tf.io.read_file(RGBD_path + name)
error = tf.io.read_file(error_path + name)
GT = tf.image.decode_png(GT)
RGBD = tf.image.decode_png(RGBD)
error = tf.image.decode_png(error)[:, :, 0]
GT = tf.cast(GT, tf.float32)// 255.0
RGBD = tf.cast(RGBD, tf.float32)/ 255.0
error = tf.cast(error, tf.float32)
error = tf.reshape(error, 1024)
return RGBD, GT, error
train_dataset = tf.data.Dataset.list_files(point_path + 'train/*.png')
train_dataset = train_dataset.map(load_image,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE).repeat()
test_dataset = tf.data.Dataset.list_files(point_path+'test/*.png')
test_dataset = test_dataset.map(load_image)
test_dataset = test_dataset.batch(BATCH_SIZE)
后来在适合 function 我想做的是:
losses = {
"GT_output": loss_functions.weighted_dice_loss,
"error_output": tf.keras.losses.MeanSquaredError(),
}
lossWeights = {"GT_output": 0.9, "error_output": 0.1}
model.compile(optimizer='adam',
loss=losses, loss_weights=lossWeights,)
model.fit(x=trainX,
y={"GT_output": trainGT, "error_output": trainError},
validation_data=(testX,
{"GT_output": testGT, "error_output": testError}),
epochs=EPOCHS,
verbose=1)
但是有没有办法将 trainX、trainGT 和 trainError 从 train_dataset 中分离出来?
谢谢...
你可以做:
take_n = 1000
trainGT = train_dataset.take(take_n)
trainError = train_dataset.skip(take_n)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.