[英]Does Tensorflow Dataset shuffle between epochs with Dataset transforms after shuffle?
我正在TensorFlow管道上工作,在該管道中我將一堆信號加載到數據集中,對這些信號進行混洗,然后對信號進行窗口化處理,然后進行批處理和重復操作。 該數據集用於通過調用model.fit函數來訓練tf.keras模型。 不要混洗信號窗口,這很重要,這就是為什么這是數據集轉換的順序。
我想知道信號的順序是否會在各個時期之間被打亂? 我發現dataset.shuffle().batch().repeat()
將使數據集在各個紀元之間進行混合,但這不適用於我的應用程序,因為在混合之后我需要進行窗口化和其他轉換。
我正在使用TensorFlow 1.13.1版本。
#... some pre-processing on the signals
signalList = [...] # a list of tuples (data, label)
dataset = tf.data.Dataset.from_generator(lambda: signalList)
dataset = dataset.shuffle(buffer_size=self.buffer_size) ## will this shuffle be repeated??
dataset = dataset.map(...) # windowing and other transforms
dataset = dataset.batch()
dataset = dataset.repeat()
model.fit(dataset, ...)
編輯:我感興趣的行為是我希望每個時代重新調整信號的順序。 所以,如果我有3個信號
signal0=[window0_0,window0_1]
signal1=[window1_0,window1_1,window1_2]
signal2=[window2_0]
那么輸出將如下所示:
tf.Tensor([signal0,signal2,signal1],...) # equivalent to tf.Tensor([window0_0,window0_1,window2_0,window1_0,window1_1,window1_2])
tf.Tensor([signal1,signal0,signal2],...) # equivalent to tf.Tensor([window1_0,window1_1,window1_2,window0_0,window0_1,window2_0])
在哪里轉換datset.map(windowing).shuffle()。batch()。repeat()會產生這樣的內容(我對此不感興趣)
tf.Tensor([window0_1,window1_1,window2_0,window1_0,window0_0,window1_2])
tf.Tensor([window0_0,window1_2,window0_1,window2_0,window1_1,window1_0])
您可以向.shuffle()
傳遞一個可選參數,以防止重新組合每個時期。
所以,如果我有一個像這樣的數據集:
def gen():
yield 1
yield 2
yield 3
ds = tf.data.Dataset.from_generator(gen, output_shapes=(), output_types=tf.int32)
然后做:
shuffled_and_batched = ds.shuffle(3).batch(3).repeat()
給出輸出:
tf.Tensor([3 2 1], shape=(3,), dtype=int32)
tf.Tensor([1 3 2], shape=(3,), dtype=int32)
tf.Tensor([2 1 3], shape=(3,), dtype=int32)
tf.Tensor([3 1 2], shape=(3,), dtype=int32)
tf.Tensor([2 3 1], shape=(3,), dtype=int32)
每個時代重新排列我的3個元素。 我了解這是您要避免的行為。
相反,如果我這樣做:
shuffled_and_batched = ds.shuffle(3, reshuffle_each_iteration=False).batch(3).repeat()
然后我得到輸出:
tf.Tensor([1 3 2], shape=(3,), dtype=int32)
tf.Tensor([1 3 2], shape=(3,), dtype=int32)
tf.Tensor([1 3 2], shape=(3,), dtype=int32)
tf.Tensor([1 3 2], shape=(3,), dtype=int32)
tf.Tensor([1 3 2], shape=(3,), dtype=int32)
順序改組一次,然后在每個時期重用。
經過一番調查,我意識到是的,即使在混洗之后和批處理之前還有其他變換, shuffle
也會在每個時期之后調用。 我不確定這對管道意味着什么(例如,我不確定窗口是否在每個時期都被調用並且正在減慢處理速度),但是我創建了一個jupyter筆記本,在其中創建了一個小版本的管道
signalList = [...] # a list of tuples (data, label)
dataset = tf.data.Dataset.from_generator(lambda: signalList)
dataset = dataset.shuffle(buffer_size=self.buffer_size)
dataset = dataset.map(...) # windowing and other transforms
dataset = dataset.batch()
dataset = dataset.repeat()
創建一個迭代器
iterator = dataset.make_one_shot_iterator()
並繪制了幾個時期的信號
next_ = iterator.get_next()
for i in range(10): # 10 epochs
full_signal = []
for j in range(29): # 29 events for this epoch
next_ = iterator.get_next()
full_signal = np.concatenate((full_signal, next_[0][0]), axis=None)
fig = plt.figure(figsize=(18, 5))
plt.plot(full_signal)
並且看到信號看上去總是處於不同的順序,這意味着它們在每個時期之后都被重新洗牌了。
如果有人有更詳細的答案,他們可以在其中解釋如何與DatasetAPI編譯一起使用,或者他們可以弄清楚這些轉換的順序是否會減慢管道的速度,我將不勝感激!
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.