
How to implement and understand Pre-processing and Data augmentation with tensorflow_datasets (tfds)?

I'm learning segmentation and data augmentation based on this TF 2.0 tutorial, which uses the Oxford-IIIT Pet dataset.

For pre-processing/data augmentation, the tutorial provides a set of functions composed into a specific pipeline:

# Import dataset
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

def normalize(input_image, input_mask):
  input_image = tf.cast(input_image, tf.float32) / 255.0
  input_mask -= 1
  return input_image, input_mask

@tf.function
def load_image_train(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  if tf.random.uniform(()) > 0.5:
    input_image = tf.image.flip_left_right(input_image)
    input_mask = tf.image.flip_left_right(input_mask)

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask

TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)

train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

This code raised several doubts for me, given the tf syntax. To keep myself from just doing a Ctrl+C/Ctrl+V and to actually understand how TensorFlow works, I would like to ask some questions:

1) In the normalize function, can the line tf.cast(input_image, tf.float32) / 255.0 be replaced with tf.image.convert_image_dtype(input_image, tf.float32)?
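From reading the docs, I suspect the two are only interchangeable on integer inputs: tf.image.convert_image_dtype rescales when converting from an integer dtype, but tf.image.resize already returns float32, so inside load_image_train the replacement would silently skip the division by 255. A quick check I sketched (not from the tutorial):

import tensorflow as tf

image_u8 = tf.constant([[[0], [127], [255]]], dtype=tf.uint8)

a = tf.cast(image_u8, tf.float32) / 255.0
b = tf.image.convert_image_dtype(image_u8, tf.float32)  # rescales uint8 to [0, 1]
print(tf.reduce_max(tf.abs(a - b)).numpy())  # ~0.0: equivalent on a uint8 input

# On a float32 input (such as the output of tf.image.resize) it is a no-op
# and does NOT divide by 255:
as_float = tf.cast(image_u8, tf.float32)  # stands in for a resized image, values 0..255
print(tf.reduce_max(tf.image.convert_image_dtype(as_float, tf.float32)).numpy())  # 255.0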

2) In the normalize function, is it possible to change my segmentation_mask values while keeping them as a tf.Tensor, without converting to NumPy? What I want is to work with only two possible mask values (0 and 1) instead of three (0, 1 and 2). Using NumPy I did something like this:

segmentation_mask_numpy = segmentation_mask.numpy()
segmentation_mask_numpy[(segmentation_mask_numpy == 2) | (segmentation_mask_numpy == 3)] = 0

Is it possible to do this without the NumPy round trip?
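I imagine something with tf.where might do it while staying in TensorFlow; a minimal sketch of what I mean (the function name is mine, untested against the tutorial):

import tensorflow as tf

def binarize_mask(mask):
    # Replace mask values 2 and 3 with 0, keep everything else, using only tf ops.
    drop = tf.logical_or(tf.equal(mask, 2), tf.equal(mask, 3))
    return tf.where(drop, tf.zeros_like(mask), mask)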

3) In the load_image_train function they say that this function performs data augmentation, but how? As I see it, they are changing the original image with a flip, based on a random draw, rather than adding another image derived from the original to the dataset. So the function's goal is to change an image, not to add an aug_image to my dataset while keeping the original? If I'm right, how can I change this function to produce an aug_image and keep the original image in the dataset, as in the sketch below?
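What I have in mind is something like concatenating a flipped copy onto the original dataset (a sketch; load_image is a hypothetical variant of load_image_train with the random flip removed):

def flip_both(image, mask):
    # Deterministically flip the image and its mask together.
    return tf.image.flip_left_right(image), tf.image.flip_left_right(mask)

base = dataset['train'].map(load_image)  # hypothetical: resize + normalize only
augmented = base.map(flip_both)          # a flipped copy of every sample
train = base.concatenate(augmented)      # originals followed by flips, dataset doubles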

4) In other questions, such as How to apply data augmentation in TensorFlow 2.0 after tfds.load() and TensorFlow 2.0 Keras: How to write image summaries for TensorBoard, a lot of sequential .map() calls are used, e.g. .map().map().cache().batch().repeat(). My question is: is this necessary? Is there a simpler way to do this? I tried to read the tf documentation, but without success.
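My current understanding, which may be wrong, is that each .map() just adds one more per-element transformation, so chained maps can always be collapsed into a single composed one; a toy check:

import tensorflow as tf

ds = tf.data.Dataset.range(5)
chained = ds.map(lambda x: x + 1).map(lambda x: x * 2)
single = ds.map(lambda x: (x + 1) * 2)

print(list(chained.as_numpy_iterator()))  # [2, 4, 6, 8, 10]
print(list(single.as_numpy_iterator()))   # [2, 4, 6, 8, 10]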

5) Would you recommend working with ImageDataGenerator from Keras, as presented here, or is this tf.data approach better?

4 - The point of these sequential calls is that they make it easy to apply transformations to the dataset, and the docs claim they are also a more performant way of loading and processing your data. Regarding modularization/simplicity, I'd say it does its job, since you can load your data, pass it through an entire preprocessing pipeline, shuffle it, and iterate over batches of it with a few lines of code.

train_dataset = tf.data.TFRecordDataset(filenames=train_records_paths).map(parsing_fn)
train_dataset = train_dataset.shuffle(buffer_size=12000)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.repeat()
# Create a test dataset
test_dataset = tf.data.TFRecordDataset(filenames=test_records_paths).map(parsing_fn)
test_dataset = test_dataset.batch(batch_size)
test_dataset = test_dataset.repeat(1)
# fit() expects an integer number of validation steps
validation_steps = test_size // batch_size
history = transferred_resnet50.fit(x=train_dataset,
                                   epochs=epochs,
                                   steps_per_epoch=steps_per_epoch,
                                   validation_data=test_dataset,
                                   validation_steps=validation_steps)

For instance, this is all I have to do in order to load my dataset and feed my model preprocessed data.

3 - They defined a preprocessing function that their dataset is mapped through, which means the function is applied every time a sample is requested. It's just like my case, where I use a parsing function to parse my data from the TFRecord format before use:

def parsing_fn(serialized):
    # Define the features each serialized example contains.
    features = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64)
    }

    # Parse the serialized data so we get a dict with our data.
    parsed_example = tf.io.parse_single_example(serialized=serialized,
                                                features=features)

    # Get the image as raw bytes.
    image_raw = parsed_example['image']

    # Decode the raw bytes so it becomes a tensor with type.
    image = tf.io.decode_jpeg(image_raw)
    
    image = tf.image.resize(image, size=[224, 224])
    
    # Get the label associated with the image.
    label = parsed_example['label']
    
    # The image and label are now correct TensorFlow types.
    return image, label

(Another example) - Using the parsing function above, I can use the code below to create a dataset, iterate through my test-set images, and plot them.

import numpy as np
import matplotlib.pyplot as plt

records_path = DATA_DIR+'/'+'TFRecords'+'/test/'+'test_0.tfrecord'
# Create a dataset
dataset = tf.data.TFRecordDataset(filenames=records_path)
# Parse the dataset using the parsing function
parsed_dataset = dataset.map(parsing_fn)
# Create a one-shot iterator over the parsed samples
iterator = tf.compat.v1.data.make_one_shot_iterator(parsed_dataset)

for i in range(100):
    image, label = iterator.get_next()
    img_array = image.numpy()
    img_array = img_array.astype(np.uint8)
    plt.imshow(img_array)
    plt.show()
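As a side note, in TF2 eager mode the compat.v1 iterator isn't strictly needed; if I'm not mistaken, the parsed dataset can be iterated directly with the same result:

for image, label in parsed_dataset.take(100):
    plt.imshow(image.numpy().astype(np.uint8))
    plt.show()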
