繁体   English   中英

使用 Tensorflow 2 输入管道分离 trainX 和 trainY 用于多损失训练

[英]Separate trainX and trainY using Tensorflow 2 input pipeline for multi-loss training

我正在构建一个网络,将一张图片作为输入并输出两个预测,并从两个预测中计算出损失 function。 所以:

输入(trainX): RGBD输出1(trainGT): GT输出2(trainError):错误

当我使用 TF 的标准输入管道时,它会压缩输入和目标。 但是,当我定义损失 function 时,我需要将两个目标分开。 这是我的输入管道代码:

@tf.function
def load_image(point_file):
    name = tf.strings.split(point_file,'/')[-1]
    GT = tf.io.read_file(GT_path + name)
    RGBD = tf.io.read_file(RGBD_path + name)
    error = tf.io.read_file(error_path + name)
    GT = tf.image.decode_png(GT)
    RGBD = tf.image.decode_png(RGBD)
    error = tf.image.decode_png(error)[:, :, 0]

    GT = tf.cast(GT, tf.float32)// 255.0
    RGBD = tf.cast(RGBD, tf.float32)/ 255.0
    error = tf.cast(error, tf.float32)
    error = tf.reshape(error, 1024)

    return RGBD, GT, error

train_dataset = tf.data.Dataset.list_files(point_path + 'train/*.png')
train_dataset = train_dataset.map(load_image,
                                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE).repeat()

test_dataset = tf.data.Dataset.list_files(point_path+'test/*.png')
test_dataset = test_dataset.map(load_image)
test_dataset = test_dataset.batch(BATCH_SIZE)

后来在适合 function 我想做的是:

losses = {
    "GT_output": loss_functions.weighted_dice_loss,
    "error_output": tf.keras.losses.MeanSquaredError(),
}
lossWeights = {"GT_output": 0.9, "error_output": 0.1}

model.compile(optimizer='adam',
              loss=losses, loss_weights=lossWeights,)

model.fit(x=trainX,
    y={"GT_output": trainGT, "error_output": trainError},
    validation_data=(testX,
        {"GT_output": testGT, "error_output": testError}),
    epochs=EPOCHS,
    verbose=1)

但是有没有办法将 trainX、trainGT 和 trainError 从 train_dataset 中分离出来?

谢谢...

你可以做:

take_n = 1000
trainGT = train_dataset.take(take_n)
trainError = train_dataset.skip(take_n)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM