繁体   English   中英

有没有办法在 Tensorflow 的另一个数据集中使用 tf.data.Dataset?

[英]Is there a way to use tf.data.Dataset inside of another Dataset in Tensorflow?

我在做分割。 每个训练样本都有多个带有分割掩码的图像。 我正在尝试编写input_fn以将每个训练样本的所有掩码图像合并为一个。 我计划使用两个Datasets ,一个遍历样本文件夹,另一个将所有掩码作为一个大批量读取,然后将它们合并到一个张量。

调用嵌套的make_one_shot_iterator时出现错误。 我知道这种方法有点牵强,而且很可能不是为这种用途而设计的数据集。 但是,我应该如何解决这个问题,以避免使用 tf.py_func?

这是数据集的简化版本:

def read_sample(sample_path):
    masks_ds = (tf.data.Dataset.
        list_files(sample_path+"/masks/*.png")
        .map(tf.read_file)
        .map(lambda x: tf.image.decode_image(x, channels=1))
        .batch(1024)) # maximum number of objects
    masks = masks_ds.make_one_shot_iterator().get_next()

    return tf.reduce_max(masks, axis=0)

ds = tf.data.Dataset.from_tensor_slices(tf.glob("../input/stage1_train/*"))
ds.map(read_sample)
# ...
sample = ds.make_one_shot_iterator().get_next()
# ...

如果嵌套数据集只有一个元素,您可以在嵌套数据集上使用tf.contrib.data.get_single_element()而不是创建迭代器:

def read_sample(sample_path):
    masks_ds = (tf.data.Dataset.list_files(sample_path+"/masks/*.png")
                .map(tf.read_file)
                .map(lambda x: tf.image.decode_image(x, channels=1))
                .batch(1024)) # maximum number of objects
    masks = tf.contrib.data.get_single_element(masks_ds)
    return tf.reduce_max(masks, axis=0)

ds = tf.data.Dataset.from_tensor_slices(tf.glob("../input/stage1_train/*"))
ds = ds.map(read_sample)
sample = ds.make_one_shot_iterator().get_next()

此外,您可以使用tf.data.Dataset.flat_map()tf.data.Dataset.interleave()tf.contrib.data.parallel_interleave()转换tf.contrib.data.parallel_interleave()在函数内执行嵌套的Dataset计算并展平结果变成单个Dataset 例如,要获取单个Dataset中的所有样本:

def read_all_samples(sample_path):
    return (tf.data.Dataset.list_files(sample_path+"/masks/*.png")
            .map(tf.read_file)
            .map(lambda x: tf.image.decode_image(x, channels=1))
            .batch(1024)) # maximum number of objects

ds = tf.data.Dataset.from_tensor_slices(tf.glob("../input/stage1_train/*"))
ds = ds.flat_map(read_all_samples)
sample = ds.make_one_shot_iterator().get_next()

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM