how to create train, test & validation split of tf.data.Dataset in tf 2.1.0

The following code is copied from: https://www.tensorflow.org/tutorials/load_data/images

The code builds a dataset of images that were downloaded from the web and stored in folders according to their classes; please refer to the link above for the full context!
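The snippet assumes some definitions from the tutorial; here is a minimal setup sketch (the flower_photos example data and the constant values follow that tutorial, adjust them for your own images):

import os
import pathlib

import numpy as np
import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE  # tf 2.1.x spelling of AUTOTUNE

# Example data used by the tutorial: one sub-directory per class.
data_dir = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    fname='flower_photos', untar=True)
data_dir = pathlib.Path(data_dir)

# Class names are the sub-directory names (skipping the bundled LICENSE file).
CLASS_NAMES = np.array([item.name for item in data_dir.glob('*')
                        if item.name != 'LICENSE.txt'])

BATCH_SIZE = 32
IMG_HEIGHT = 224
IMG_WIDTH = 224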

list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'))

for f in list_ds.take(5):
  print(f.numpy())

def get_label(file_path):
  # convert the path to a list of path components
  parts = tf.strings.split(file_path, os.path.sep)
  # The second to last is the class-directory
  return parts[-2] == CLASS_NAMES

def decode_img(img):
  # convert the compressed string to a 3D uint8 tensor
  img = tf.image.decode_jpeg(img, channels=3)
  # Use `convert_image_dtype` to convert to floats in the [0,1] range.
  img = tf.image.convert_image_dtype(img, tf.float32)
  # resize the image to the desired size.
  return tf.image.resize(img, [IMG_WIDTH, IMG_HEIGHT])

def process_path(file_path):
  label = get_label(file_path)
  # load the raw data from the file as a string
  img = tf.io.read_file(file_path)
  img = decode_img(img)
  return img, label

# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
labeled_ds = list_ds.map(process_path, num_parallel_calls=AUTOTUNE)

for image, label in labeled_ds.take(1):
  print("Image shape: ", image.numpy().shape)
  print("Label: ", label.numpy())

def prepare_for_training(ds, cache=True, shuffle_buffer_size=1000):
  # This is a small dataset, only load it once, and keep it in memory.
  # use `.cache(filename)` to cache preprocessing work for datasets that don't
  # fit in memory.
  if cache:
    if isinstance(cache, str):
      ds = ds.cache(cache)
    else:
      ds = ds.cache()

  ds = ds.shuffle(buffer_size=shuffle_buffer_size)

  # Repeat forever
  ds = ds.repeat()

  ds = ds.batch(BATCH_SIZE)

  # `prefetch` lets the dataset fetch batches in the background while the model
  # is training.
  ds = ds.prefetch(buffer_size=AUTOTUNE)

  return ds

train_ds = prepare_for_training(labeled_ds)

We are finally left with train_ds, a PrefetchDataset object that contains the entire dataset of (image, label) pairs! How do I split train_ds into train, test & validation sets to feed into a model?

After the ds.repeat() call the dataset is infinite, and splitting an infinite dataset doesn't work very well. Therefore you should split it before the prepare_for_training() call. Like this:

labeled_ds = list_ds.map(process_path, num_parallel_calls=AUTOTUNE)
# Shuffle once with a fixed order (reshuffle_each_iteration=False) so the
# take/skip splits below stay disjoint across epochs.
labeled_ds = labeled_ds.shuffle(10000, reshuffle_each_iteration=False).batch(BATCH_SIZE)

# Size of the dataset, counted in batches (it was batched above)
n = sum(1 for _ in labeled_ds)
n_train = int(n * 0.8)
n_valid = int(n * 0.1)
n_test = n - n_train - n_valid

train_ds = labeled_ds.take(n_train)
valid_ds = labeled_ds.skip(n_train).take(n_valid)
test_ds = labeled_ds.skip(n_train + n_valid).take(n_test)

The line n = sum(1 for _ in labeled_ds) iterates through the dataset once to get its size (counted in batches, since batch() was applied first); the dataset is then split three ways into 80%/10%/10%.
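To then feed the splits into a model, here is a minimal usage sketch (the cast helper, the toy model, and the epoch count are illustrative assumptions, not part of the answer above). The splits are already batched, so only repeat/prefetch is added; because the training input repeats forever, model.fit needs steps_per_epoch, and n_train already counts batches:

# get_label() returns a boolean one-hot vector; cast it to float for the loss.
def to_float_labels(image, label):
  return image, tf.cast(label, tf.float32)

train_input = train_ds.map(to_float_labels).repeat().prefetch(buffer_size=AUTOTUNE)
valid_input = valid_ds.map(to_float_labels)
test_input = test_ds.map(to_float_labels)

# Placeholder model, just to show the wiring.
model = tf.keras.Sequential([
  tf.keras.layers.Flatten(input_shape=(IMG_WIDTH, IMG_HEIGHT, 3)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(len(CLASS_NAMES), activation='softmax'),
])
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# n_train counts batches (the dataset was batched before counting),
# so it can be used directly as the number of steps per epoch.
model.fit(train_input,
          steps_per_epoch=n_train,
          validation_data=valid_input,
          epochs=5)

model.evaluate(test_input)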
