
Is there a way to use tf.data.Dataset inside of another Dataset in TensorFlow?

I'm doing segmentation. Each training sample has multiple images with segmentation masks. I'm trying to write an input_fn that merges all mask images into one for each training sample. I was planning on using two Datasets: one that iterates over the sample folders, and another that reads all of a sample's masks as one large batch and then merges them into one tensor.
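The merge described above is an element-wise maximum over the stack of per-object masks, which is what tf.reduce_max(masks, axis=0) computes on a batch. The following is a minimal NumPy sketch of that reduction. It uses three small, made-up 0/255 masks purely for illustration; the real masks would be decoded PNGs with an extra channel axis.

```python
import numpy as np

# Three hypothetical 3x3 single-object masks (0 = background, 255 = object),
# stacked along axis 0 the same way a batch of decoded masks would be.
masks = np.array([
    [[255, 0, 0], [0, 0, 0], [0, 0, 0]],
    [[0, 0, 0], [0, 255, 0], [0, 0, 0]],
    [[0, 0, 0], [0, 0, 0], [0, 0, 255]],
], dtype=np.uint8)

# Element-wise maximum over the object axis: a pixel is foreground in the
# merged mask if it is foreground in any of the individual masks.
merged = masks.max(axis=0)
print(merged)
```

Background pixels are 0 and object pixels are nonzero, so the maximum keeps every object's pixels. It does not record which object a pixel came from.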

I'm getting an error when the nested make_one_shot_iterator() is called. I know that this approach is a bit of a stretch and that datasets most likely weren't designed for such usage. But then how should I approach this problem so that I avoid using tf.py_func?

Here is a simplified version of the dataset:

def read_sample(sample_path):
    masks_ds = (tf.data.Dataset.
        .map(lambda x: tf.image.decode_image(x, channels=1))
        .batch(1024)) # maximum number of objects
    masks = masks_ds.make_one_shot_iterator().get_next()

    return tf.reduce_max(masks, axis=0)

ds = tf.data.Dataset.from_tensor_slices(tf.glob("../input/stage1_train/*"))
# ...
sample = ds.make_one_shot_iterator().get_next()
# ...

If the nested dataset has only a single element, you can use tf.contrib.data.get_single_element() on the nested dataset instead of creating an iterator:

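import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib)

def read_sample(sample_path):
    # Nested dataset for one sample: list its mask files, read each file
    # and decode it as a single-channel image.
    masks_ds = (tf.data.Dataset.list_files(sample_path + "/masks/*.png")
                .map(lambda x: tf.image.decode_image(tf.read_file(x), channels=1))
                .batch(1024))  # maximum number of objects per sample
    # Assuming no sample has more than 1024 masks, masks_ds yields exactly one element.
    masks = tf.contrib.data.get_single_element(masks_ds)
    # Element-wise max merges all per-object masks into one mask.
    return tf.reduce_max(masks, axis=0)

ds = tf.data.Dataset.from_tensor_slices(tf.gfile.Glob("../input/stage1_train/*"))
ds = ds.map(read_sample)
sample = ds.make_one_shot_iterator().get_next()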
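get_single_element() only works when the nested dataset yields exactly one element. That is why the masks are batched with a batch size larger than the number of masks any sample is expected to have. The following plain-Python model is not TensorFlow code: batch and get_single_element here are hypothetical stand-ins that only mimic the semantics. It shows that a sample with more masks than the batch size produces two batches, which breaks the single-element contract. In TensorFlow that situation fails at run time.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")
_MISSING = object()  # sentinel returned by next() when an iterator is exhausted

def batch(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Group items into lists of at most batch_size (last one may be smaller)."""
    it = iter(items)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk:
            return
        yield chunk

def get_single_element(elements: Iterable[T]) -> T:
    """Return the only element; raise if there are zero or several."""
    it = iter(elements)
    first = next(it, _MISSING)
    if first is _MISSING:
        raise ValueError("dataset is empty")
    if next(it, _MISSING) is not _MISSING:
        raise ValueError("dataset has more than one element")
    return first

# 3 hypothetical masks fit into one batch of 1024: exactly one element.
print(len(get_single_element(batch(range(3), 1024))))

# 1500 masks produce two batches (1024 + 476), so the contract is violated.
try:
    get_single_element(batch(range(1500), 1024))
except ValueError as err:
    print(err)
```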

In addition, you can use the tf.data.Dataset.flat_map(), tf.data.Dataset.interleave(), or tf.contrib.data.parallel_interleave() transformations to perform a nested Dataset computation inside a function and flatten the result into a single Dataset. For example, to get all of the samples in a single Dataset:

import tensorflow as tf  # TensorFlow 1.x

def read_all_samples(sample_path):
    # Return a nested Dataset; flat_map() splices its elements into the outer one.
    return (tf.data.Dataset.list_files(sample_path + "/masks/*.png")
            .map(lambda x: tf.image.decode_image(tf.read_file(x), channels=1))
            .batch(1024))  # maximum number of objects per sample

ds = tf.data.Dataset.from_tensor_slices(tf.gfile.Glob("../input/stage1_train/*"))
ds = ds.flat_map(read_all_samples)
sample = ds.make_one_shot_iterator().get_next()
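The following simplified plain-Python model shows how these transformations order the elements of the nested datasets. It is not TensorFlow code: lists stand in for Datasets, and flat_map, interleave and the samples dict are hypothetical illustrations. flat_map() concatenates each nested result in input order. interleave() keeps up to cycle_length nested results open and takes block_length elements from each in turn. The model is meant to reproduce the ordering of the example in the interleave() documentation, but it is a sketch, not TensorFlow's implementation.

```python
from typing import Callable, Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")
U = TypeVar("U")
_MISSING = object()  # sentinel returned by next() when an iterator is exhausted

def flat_map(outer: Iterable[T], fn: Callable[[T], Iterable[U]]) -> Iterator[U]:
    """Concatenate fn(x) for every x, in input order."""
    for x in outer:
        yield from fn(x)

def interleave(outer: Iterable[T], fn: Callable[[T], Iterable[U]],
               cycle_length: int, block_length: int = 1) -> Iterator[U]:
    """Round-robin over up to cycle_length inner iterables, block_length at a time."""
    inputs = iter(outer)
    slots: List[Optional[Iterator[U]]] = [None] * cycle_length
    inputs_done = False
    while True:
        any_active = False
        for i in range(cycle_length):
            # An empty slot is refilled with the next input element when visited.
            if slots[i] is None and not inputs_done:
                x = next(inputs, _MISSING)
                if x is _MISSING:
                    inputs_done = True
                else:
                    slots[i] = iter(fn(x))
            it = slots[i]
            if it is None:
                continue
            any_active = True
            for _ in range(block_length):
                value = next(it, _MISSING)
                if value is _MISSING:
                    slots[i] = None  # exhausted: move on to the next slot
                    break
                yield value
        if not any_active:
            return

# Hypothetical sample folders mapped to their mask file names.
samples = {"sample_a": ["a0.png", "a1.png"], "sample_b": ["b0.png"]}
print(list(flat_map(samples, lambda s: samples[s])))
print(list(interleave(samples, lambda s: samples[s], cycle_length=2)))
```

In this model, cycle_length=1 reproduces flat_map()'s order. The model ignores parallelism, so it says nothing about the throughput benefits of parallel_interleave().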

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM