I want to pass a list of tf.Strings to the .map(_parse_function) function.
def _parse_function(self, img_path):
    img_str = tf.read_file(img_path)
    img_decode = tf.image.decode_jpeg(img_str, channels=3)
    img_decode = tf.divide(tf.cast(img_decode, tf.float32), 255)
    return img_decode
When the tf.data.Dataset is of type TensorSliceDataset,
dataset_from_slices = tf.data.Dataset.from_tensor_slices((tensor_with_filenames))
I can simply do dataset_from_slices.map(_parse_function), which works.
However, dataset_from_generator = tf.data.Dataset.from_generator(...) returns a Dataset that is an instance of FlatMapDataset, and dataset_from_generator.map(_parse_function) gives the following error:
InvalidArgumentError: Input filename tensor must be scalar, but had shape: [32]
If I change the first line to:
img_str = tf.read_file(img_path[0])
that also works but then I only get the first image, which is not what I am looking for. Any suggestions?
It sounds like the elements of your dataset_from_generator are batched. The simplest remedy is to use tf.contrib.data.unbatch() to convert them back into individual elements:
# Each element is a vector of strings.
dataset_from_generator = tf.data.Dataset.from_generator(...)
# Converts each vector of strings into multiple individual elements.
dataset = dataset_from_generator.apply(tf.contrib.data.unbatch())
dataset = dataset.map(_parse_function)
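To make the effect of unbatching concrete, here is a minimal pure-Python sketch of the same idea, with no TensorFlow involved. The names batched_filenames and unbatch and the file names are hypothetical and only for illustration; the real generator in the question is not shown, so a batch size of 2 is assumed here instead of 32. As far as I know, later TensorFlow releases that dropped tf.contrib expose the same operation as the Dataset.unbatch() method.

```python
from itertools import chain, islice


def batched_filenames(paths, batch_size):
    """Hypothetical generator: yields lists of paths, like a batched dataset."""
    it = iter(paths)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch


def unbatch(batches):
    """Flatten an iterable of batches into single elements, preserving order."""
    return chain.from_iterable(batches)


# Made-up file names, for illustration only.
paths = ["img_%d.jpg" % i for i in range(5)]

# Each element is a list of strings (the analogue of a shape-[32] tensor).
batches = list(batched_filenames(paths, batch_size=2))

# Each element is now a single string (the analogue of a scalar tensor),
# which is what a per-file function such as _parse_function expects.
singles = list(unbatch(batches))
```

If you control the generator yourself, another option is to have it yield one filename at a time, so that every dataset element is already a scalar and no unbatching step is needed.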