簡體   English   中英

在tensorflow數據集api中優化shuffle緩沖區大小

[英]Optimizing shuffle buffer size in tensorflow dataset api

我正在嘗試使用dataset api來加載數據,並發現我花費了大部分時間將數據加載到shuffle緩沖區中。 我如何優化此管道以最小化填充shuffle緩沖區所花費的時間。

(tf.data.Dataset.list_files(path)
   .shuffle(num_files)  # number of tfrecord files 
   .apply(tf.contrib.data.parallel_interleave(lambda f: tf.data.TFRecordDataset(f), cycle_length=num_files))
   .shuffle(num_items)  # number of images in the dataset
   .map(parse_func, num_parallel_calls=8)
   .map(get_patches, num_parallel_calls=8)
   .apply(tf.contrib.data.unbatch())
   # Patch buffer is currently the number of patches extracted per image
   .apply(tf.contrib.data.shuffle_and_repeat(patch_buffer))
   .batch(64)
   .prefetch(1)
   .make_one_shot_iterator())

由於我有至多數千張圖像,我對此問題的解決方案是每張圖像都有一個單獨的tfrecord文件。 這樣,個人圖像可以被洗牌,而不必先將它們加載到內存中。 這大大減少了需要發生的緩沖。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM