Transforming a tf.data.Dataset

Let's say my source data is a dataset of 32x32x3 images with the following structure:

<DatasetV1Adapter shapes: {coarse_label: (), image: (32, 32, 3), label: ()}, types: {coarse_label: tf.int64, image: tf.uint8, label: tf.int64}>

After serializing the data, I get:

<MapDataset shapes: {depth: (), height: (), image_raw: (), label: (), width: ()}, types: {depth: tf.int64, height: tf.int64, image_raw: tf.string, label: tf.int64, width: tf.int64}>

I can access each element using this piece of code:

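import numpy as np
# e1 is one element of the original (pre-serialization) dataset
array_shape = e1['image'].numpy().shape
for i in parsed_image_dataset.take(1):
  j = i['image_raw']
  # Reinterpret the raw bytes as uint8 and restore the image shape
  print(np.frombuffer(j.numpy(), dtype='uint8').reshape(array_shape))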

where e1 has been generated using get_next on the original dataset. As expected, this prints an image identical to the one before serialization. However, instead of doing this element by element, could I somehow transform my serialized dataset directly back into the original uint8 one?

You can get the images in uint8 format by following the steps below.

Create a dataset of the image file paths.

list_ds = tf.data.Dataset.list_files("img_dir_path/*")

Create a function that will take the file_path as an argument and return the image in uint8 format.

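def process_img(file_path):
  # Read the encoded file contents as a string tensor
  img = tf.io.read_file(file_path)
  # Decode the JPEG into a uint8 tensor with 3 channels
  img = tf.image.decode_jpeg(img, channels=3)
  return img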
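
As a quick sanity check, the function can be tried on a single file before mapping it over a dataset. The sketch below assumes TensorFlow 2.x with eager execution; the temporary directory, the file name sample.jpg, and the random pixel data are made up purely for illustration.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def process_img(file_path):
    # Read the encoded bytes and decode them into a uint8 tensor of shape (H, W, 3).
    img = tf.io.read_file(file_path)
    return tf.image.decode_jpeg(img, channels=3)


# Hypothetical test image: random 32x32 RGB pixels written to disk as a JPEG.
tmp_dir = tempfile.mkdtemp()
sample_path = os.path.join(tmp_dir, "sample.jpg")
pixels = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
tf.io.write_file(sample_path, tf.io.encode_jpeg(pixels))

decoded = process_img(sample_path)
print(decoded.dtype, decoded.shape)
```

Because JPEG is lossy, the decoded pixel values will generally differ slightly from the ones that were written; only the dtype and shape are expected to match.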

Use the map function to apply the above function to all the items in the list_ds object.

processed_images = list_ds.map(process_img)

processed_images will contain images in uint8 format for the given image directory.
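
If the data is already in the serialized form shown in the question, rather than in image files, the same map approach may work with a decoding function in place of process_img. The sketch below is only one possible way to do it. It assumes TensorFlow 2.x with eager execution, and it assumes that image_raw holds the raw uint8 pixel bytes, as the np.frombuffer round trip in the question suggests. The stand-in dataset and the decode_image helper are hypothetical.

```python
import numpy as np
import tensorflow as tf

# Stand-in for the serialized dataset from the question: raw uint8 bytes
# plus the shape fields (height, width, depth). The pixel values are synthetic.
images = (np.arange(4 * 32 * 32 * 3) % 256).astype(np.uint8).reshape(4, 32, 32, 3)
serialized = tf.data.Dataset.from_tensor_slices({
    "image_raw": [img.tobytes() for img in images],
    "height": np.full(4, 32, dtype=np.int64),
    "width": np.full(4, 32, dtype=np.int64),
    "depth": np.full(4, 3, dtype=np.int64),
})


def decode_image(example):
    # Reinterpret the byte string as a flat uint8 vector,
    # then restore the shape stored alongside it.
    flat = tf.io.decode_raw(example["image_raw"], tf.uint8)
    shape = tf.stack([example["height"], example["width"], example["depth"]])
    return tf.reshape(flat, shape)


# Apply the decoding to every element instead of looping by hand.
image_dataset = serialized.map(decode_image)

for image in image_dataset.take(1):
    print(image.dtype, image.shape)
```

Since the shape is read from the tensors at run time, the static shape of each element is unknown inside the pipeline. If all images are known to be 32x32x3, reshaping to the fixed list [32, 32, 3] would keep the static shape.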
