繁体   English   中英

在tensorflow数据集api中优化shuffle缓冲区大小

[英]Optimizing shuffle buffer size in tensorflow dataset api

我正在尝试使用dataset api来加载数据,并发现我花费了大部分时间将数据加载到shuffle缓冲区中。 我如何优化此管道以最小化填充shuffle缓冲区所花费的时间。

(tf.data.Dataset.list_files(path)
   .shuffle(num_files)  # number of tfrecord files 
   .apply(tf.contrib.data.parallel_interleave(lambda f: tf.data.TFRecordDataset(f), cycle_length=num_files))
   .shuffle(num_items)  # number of images in the dataset
   .map(parse_func, num_parallel_calls=8)
   .map(get_patches, num_parallel_calls=8)
   .apply(tf.contrib.data.unbatch())
   # Patch buffer is currently the number of patches extracted per image
   .apply(tf.contrib.data.shuffle_and_repeat(patch_buffer))
   .batch(64)
   .prefetch(1)
   .make_one_shot_iterator())

由于我有至多数千张图像,我对此问题的解决方案是每张图像都有一个单独的tfrecord文件。 这样,个人图像可以被洗牌,而不必先将它们加载到内存中。 这大大减少了需要发生的缓冲。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM