Are we losing data when we use .next() or .take() on a tf.keras.preprocessing.image_dataset_from_directory object?

I create a dataset like this:

import tensorflow as tf

# Create test_dataset from the images under test_dir
test_dataset = \
  tf.keras.preprocessing.image_dataset_from_directory(directory=test_dir,
                                                      labels='inferred', 
                                                      label_mode='int', 
                                                      class_names=None,
                                                      seed=42, 
                                                      )
# Explore the first batch
for images, labels in test_dataset.take(1):
  print(labels)

It prints:

tf.Tensor([5 3 8 3 8 5 7 6 3 8 4 2 4 5 5 4 0 1 0 5 5 2 6 0 7 9 9 0 4 9 6 4], shape=(32,), dtype=int32)

If I re-run the last part as below:

for images, labels in test_dataset.take(1):
  print(labels)

It prints something different from the first time:

tf.Tensor([0 6 2 5 5 7 5 2 7 4 0 5 0 4 6 5 8 7 7 3 5 1 1 9 5 2 6 6 6 6 2 0], shape=(32,), dtype=int32)

If I recreate test_dataset and explore it as below:

# Create test_dataset
test_dataset = \
  tf.keras.preprocessing.image_dataset_from_directory(directory=test_dir,
                                                      labels='inferred', 
                                                      label_mode='int', 
                                                      class_names=None,
                                                      seed=42, 
                                                      )
# Explore the first batch
for images, labels in test_dataset.take(1):
  print(labels)

It prints the same as the first time:

tf.Tensor([5 3 8 3 8 5 7 6 3 8 4 2 4 5 5 4 0 1 0 5 5 2 6 0 7 9 9 0 4 9 6 4], shape=(32,), dtype=int32)

So I concluded that when I use the take method, the batch is popped out and lost, and is no longer available for modeling, validation, etc.

My question is:

  • Am I right? Is the first batch lost if I run test_dataset.take(1)?
  • If so, is there any way to explore batches of a tf.keras.preprocessing.image_dataset_from_directory object without losing them?

It's not about losing the batch. The function tf.keras.preprocessing.image_dataset_from_directory has a shuffle argument whose default value is True, so the dataset is reshuffled on each iteration.

If we look at the function's source code:

  if shuffle:
    # Shuffle locally at each iteration
    dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed)
  dataset = dataset.batch(batch_size)
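A minimal sketch of the same two steps, assuming TensorFlow 2.x and using a toy tf.data.Dataset.range(100) in place of the image files. The 100 elements and the helpers first_batch and full_epoch are illustrative, not part of the Keras source; batch_size=32 is the Keras default and seed=42 comes from the question. Each peek starts a new pass and may see a different order, but one complete pass still yields every element exactly once:

```python
import tensorflow as tf

# Toy stand-in for the list of image files: 100 "samples" numbered 0..99.
batch_size = 32
seed = 42

dataset = tf.data.Dataset.range(100)
# Same two steps as the excerpt: a local shuffle, then batching.
dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed)
dataset = dataset.batch(batch_size)


def first_batch(ds):
    """Return the first batch as a list; every call starts a new pass."""
    for batch in ds.take(1):
        return batch.numpy().tolist()


def full_epoch(ds):
    """Return every element seen in one complete pass over ds."""
    return [int(x) for batch in ds for x in batch.numpy()]


peek_1 = first_batch(dataset)
peek_2 = first_batch(dataset)  # new pass, so the order is reshuffled

epoch = full_epoch(dataset)
print(sorted(epoch) == list(range(100)))  # True: nothing was consumed
```

Because shuffle only permutes elements within each pass, take(1) reads from a fresh iteration each time instead of removing a batch from the dataset.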

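Under the hood, as you can see, it builds a tf.data.Dataset and calls its shuffle method. Dataset.shuffle has an argument reshuffle_each_iteration, which is True by default. Each for loop over test_dataset.take(1) starts a new iteration over the dataset, so the data is reshuffled and a different first batch comes out. Recreating the dataset with the same seed restarts the shuffle sequence, which is why you got the original first batch back.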
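To see the effect of reshuffle_each_iteration directly, here is a sketch at the tf.data level, assuming TF 2.x; make_dataset and first_batch are hypothetical helpers. Note that image_dataset_from_directory does not expose reshuffle_each_iteration itself, so there its shuffle flag is the available knob.

```python
import tensorflow as tf


def make_dataset(reshuffle):
    # 100 toy elements, shuffled with a fixed seed, then batched by 32.
    ds = tf.data.Dataset.range(100)
    ds = ds.shuffle(buffer_size=256, seed=42,
                    reshuffle_each_iteration=reshuffle)
    return ds.batch(32)


def first_batch(ds):
    # iter(ds) creates a new iterator, i.e. starts a new pass over ds.
    return next(iter(ds)).numpy().tolist()


fixed = make_dataset(reshuffle=False)      # same order on every pass
reshuffled = make_dataset(reshuffle=True)  # the tf.data default

print(first_batch(fixed) == first_batch(fixed))  # True
# With the default, consecutive peeks typically return different batches.
print(first_batch(reshuffled))
print(first_batch(reshuffled))
```

Using next(iter(ds)) behaves like the for ... in ds.take(1) loop in the question: both start a new pass and read only its first batch.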

If you set shuffle=False for the dataset, the files are sorted in alphanumeric order and that order does not change from one iteration to the next.
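A self-contained check, assuming TF 2.x. So that it runs without the asker's test_dir, it writes four blank 8x8 PNGs into a temporary folder using hypothetical class folders cat and dog. With shuffle=False, peeking at the first batch twice returns the same labels:

```python
import os
import tempfile

import tensorflow as tf

# Tiny throwaway directory tree: two classes, two PNG files each.
# Folder and file names are hypothetical placeholders.
root = tempfile.mkdtemp()
for class_name in ("cat", "dog"):
    os.makedirs(os.path.join(root, class_name))
    for i in range(2):
        pixels = tf.zeros([8, 8, 3], dtype=tf.uint8)
        tf.io.write_file(os.path.join(root, class_name, f"img_{i}.png"),
                         tf.io.encode_png(pixels))

test_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    directory=root,
    labels="inferred",
    label_mode="int",
    image_size=(8, 8),
    batch_size=2,
    shuffle=False,  # keep files in alphanumeric order on every pass
)


def peek_labels(ds):
    # Each call starts a fresh pass over the dataset.
    for _, labels in ds.take(1):
        return labels.numpy().tolist()


print(peek_labels(test_dataset) == peek_labels(test_dataset))  # True
```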
