Convert a TensorFlow tf.data.Dataset FlatMapDataset to TensorSliceDataset

I want to pass a list of tf.string filenames to .map(_parse_function).

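def _parse_function(self, img_path):
    # Read the raw bytes of one file; img_path must be a scalar string tensor.
    img_str = tf.read_file(img_path)
    img_decode = tf.image.decode_jpeg(img_str, channels=3)
    # Scale pixel values from [0, 255] to [0, 1].
    img_decode = tf.divide(tf.cast(img_decode, tf.float32), 255)
    return img_decode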

When the tf.data.Dataset is of type TensorSliceDataset,

dataset_from_slices = tf.data.Dataset.from_tensor_slices((tensor_with_filenames))

I can simply do dataset_from_slices.map(_parse_function), which works.

However, dataset_from_generator = tf.data.Dataset.from_generator(...) returns a Dataset of type FlatMapDataset, and dataset_from_generator.map(_parse_function) fails with the following error:

InvalidArgumentError: Input filename tensor must be scalar, but had shape: [32]

If I change the first line to:

img_str = tf.read_file(img_path[0])

that also works, but then I only get the first image, which is not what I want. Any suggestions?

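It sounds like the elements of your dataset_from_generator are batched: the error reports shape [32], so each element is a vector of 32 filenames, while tf.read_file expects a single scalar filename. The simplest remedy is to use tf.contrib.data.unbatch() to convert them back into individual elements: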

# Each element is a vector of strings.
dataset_from_generator = tf.data.Dataset.from_generator(...)

# Converts each vector of strings into multiple individual elements.
dataset = dataset_from_generator.apply(tf.contrib.data.unbatch())

dataset = dataset.map(_parse_function)
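For readers on TensorFlow 2.x, where tf.contrib no longer exists, the same idea can be sketched with the Dataset.unbatch() method. This is only a minimal illustration, assuming TensorFlow 2.4 or later (for the output_signature argument) and eager execution; the generator batched_filenames and the file names in it are hypothetical stand-ins for the real generator, which the question does not show.

```python
import tensorflow as tf


def batched_filenames():
    # Hypothetical generator: each yielded element is a vector of filenames,
    # mimicking the batched elements described in the question.
    yield ["a.jpg", "b.jpg"]
    yield ["c.jpg", "d.jpg"]


# Each element of this dataset is a 1-D string tensor (a batch of filenames).
batched = tf.data.Dataset.from_generator(
    batched_filenames,
    output_signature=tf.TensorSpec(shape=(None,), dtype=tf.string),
)

# unbatch() splits every vector into individual scalar elements, so a
# function that expects one filename per call can be mapped afterwards.
unbatched = batched.unbatch()

# Collect the scalar filenames to check the result (eager mode only).
names = [element.numpy().decode("utf-8") for element in unbatched]
```

After this step each element is a scalar string, so unbatched.map(...) can be given a parse function that reads one file at a time, using tf.io.read_file in place of tf.read_file.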
