簡體   English   中英

具有部分混洗的Tensorflow數據集

[英]Tensorflow dataset with partial shuffle

根據文檔,我正在使用TensorFlow的數據集API,並且對shuffle()方法感到困惑:

Dataset.shuffle()轉換使用與tf.RandomShuffleQueue相似的算法對輸入數據集進行隨機混洗:它維護固定大小的緩沖區,並從該緩沖區中均勻地隨機選擇下一個元素。

如果我只是“部分地”改組我的數據集(例如,buffer_size <=元素數),我希望只有第一個buffer_size元素會被改組,但是事實並非如此,請參見示例:

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8])
                         .shuffle(buffer_size=4, seed=42)
                         .batch(2)
iter = dataset.make_initializable_iterator() # create the iterator
el = iter.get_next()
with tf.Session() as sess:
    sess.run(iter.initializer) 
    print('batch:', sess.run(el))

輸出:

batch: [2 5]

為什么在這里5? 因為緩沖區大小只有4個? 前2個元素應該在1〜4之內吧? 我在這里想念什么?

謝謝

簡短的答案是,可以在任何時候(包括在創建批處理過程中)補充隨機緩沖區。

這是您的觀察可能發生的方式:

  • 數據集從數據中讀取前4個元素。 隨機緩沖區現在包含[1、2、3、4]
  • 您請求兩個元素(通過創建2個批次的數據集上的get_next())
  • 隨機數據集選擇2並將下一個元素讀取到隨機緩沖區中,該緩沖區現在包含[1、3、4、5]。
  • 隨機數據集從緩沖區中選取5個。
  • 您的批次[2,5]被退回。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM