簡體   English   中英

如何使用自定義生成器使tf.data.Dataset.from_generator產生批處理

[英]How to make tf.data.Dataset.from_generator yield batches with a custom generator

我想使用tf.data API。 我期望的工作流程如下所示:

  • 輸入圖像是具有(batch_size, width, height, channels, frames)的5D張量

  • 第一層是3D卷積

我使用tf.data.from_generator函數創建一個迭代器。 后來我做了一個初始化的迭代器。

我的代碼如下所示:

def custom_gen():
   img = np.random.normal((width, height, channels, frames))
   yield(img, img) # I train an autoencoder, so the x == y`

dataset = tf.data.Dataset.batch(batch_size).from_generator(custom_generator)
iter = dataset.make_initializable_iterator()

sess = tf.Session()
sess.run(iter.get_next())

我希望iter.get_next()我生成具有批處理大小的5D張量。 但是,我什至嘗試在自己的custom_generator產生批處理大小,但它不起作用。 當我想用輸入形狀(batch_size, width, height, channels, frames)占位符初始化數據集時,遇到一個錯誤。

該示例中的Dataset構建過程Dataset不正確。 應當按照導入數據的官方指南中規定的順序進行:

  1. 應該調用基本數據集創建函數或靜態方法來建立原始數據 (例如,靜態方法from_slice_tensorsfrom_generatorlist_files等)。
  2. 此時,可以通過鏈接適配器方法(例如batch )來應用轉換

從而:

dataset = tf.data.Dataset.from_generator(custom_generator).batch(batch_size)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM