繁体   English   中英

混洗后,Tensorflow数据集是否在具有数据集转换的历元之间进行混洗?

[英]Does Tensorflow Dataset shuffle between epochs with Dataset transforms after shuffle?

我正在TensorFlow管道上工作,在该管道中我将一堆信号加载到数据集中,对这些信号进行混洗,然后对信号进行窗口化处理,然后进行批处理和重复操作。 该数据集用于通过调用model.fit函数来训练tf.keras模型。 不要混洗信号窗口,这很重要,这就是为什么这是数据集转换的顺序。

我想知道信号的顺序是否会在各个时期之间被打乱? 我发现dataset.shuffle().batch().repeat()将使数据集在各个纪元之间进行混合,但这不适用于我的应用程序,因为在混合之后我需要进行窗口化和其他转换。

我正在使用TensorFlow 1.13.1版本。

#... some pre-processing on the signals 
signalList = [...] # a list of tuples (data, label)
dataset = tf.data.Dataset.from_generator(lambda: signalList)
dataset = dataset.shuffle(buffer_size=self.buffer_size)  ## will this shuffle be repeated??
dataset = dataset.map(...) # windowing and other transforms
dataset = dataset.batch()
dataset = dataset.repeat()

model.fit(dataset, ...)

编辑:我感兴趣的行为是我希望每个时代重新调整信号的顺序。 所以,如果我有3个信号

signal0=[window0_0,window0_1]
signal1=[window1_0,window1_1,window1_2]
signal2=[window2_0]

那么输出将如下所示:

tf.Tensor([signal0,signal2,signal1],...) # equivalent to tf.Tensor([window0_0,window0_1,window2_0,window1_0,window1_1,window1_2])
tf.Tensor([signal1,signal0,signal2],...) # equivalent to tf.Tensor([window1_0,window1_1,window1_2,window0_0,window0_1,window2_0]) 

在哪里转换datset.map(windowing).shuffle()。batch()。repeat()会产生这样的内容(我对此不感兴趣)

tf.Tensor([window0_1,window1_1,window2_0,window1_0,window0_0,window1_2])
tf.Tensor([window0_0,window1_2,window0_1,window2_0,window1_1,window1_0]) 

您可以向.shuffle()传递一个可选参数,以防止重新组合每个时期。

所以,如果我有一个像这样的数据集:

def gen():
  yield 1
  yield 2
  yield 3

ds = tf.data.Dataset.from_generator(gen, output_shapes=(), output_types=tf.int32)

然后做:

shuffled_and_batched = ds.shuffle(3).batch(3).repeat()

给出输出:

tf.Tensor([3 2 1], shape=(3,), dtype=int32)
tf.Tensor([1 3 2], shape=(3,), dtype=int32)
tf.Tensor([2 1 3], shape=(3,), dtype=int32)
tf.Tensor([3 1 2], shape=(3,), dtype=int32)
tf.Tensor([2 3 1], shape=(3,), dtype=int32)

每个时代重新排列我的3个元素。 我了解这是您要避免的行为。

相反,如果我这样做:

shuffled_and_batched = ds.shuffle(3, reshuffle_each_iteration=False).batch(3).repeat()

然后我得到输出:

tf.Tensor([1 3 2], shape=(3,), dtype=int32)
tf.Tensor([1 3 2], shape=(3,), dtype=int32)
tf.Tensor([1 3 2], shape=(3,), dtype=int32)
tf.Tensor([1 3 2], shape=(3,), dtype=int32)
tf.Tensor([1 3 2], shape=(3,), dtype=int32)

顺序改组一次,然后在每个时期重用。

经过一番调查,我意识到是的,即使在混洗之后和批处理之前还有其他变换, shuffle也会在每个时期之后调用。 我不确定这对管道意味着什么(例如,我不确定窗口是否在每个时期都被调用并且正在减慢处理速度),但是我创建了一个jupyter笔记本,在其中创建了一个小版本的管道

signalList = [...] # a list of tuples (data, label)
dataset = tf.data.Dataset.from_generator(lambda: signalList)
dataset = dataset.shuffle(buffer_size=self.buffer_size)  
dataset = dataset.map(...) # windowing and other transforms
dataset = dataset.batch()
dataset = dataset.repeat()

创建一个迭代器

iterator = dataset.make_one_shot_iterator()

并绘制了几个时期的信号

next_ = iterator.get_next()
for i in range(10):  # 10 epochs
    full_signal = []
    for j in range(29):  # 29 events for this epoch
        next_ = iterator.get_next()
        full_signal = np.concatenate((full_signal, next_[0][0]), axis=None)

    fig = plt.figure(figsize=(18, 5))
    plt.plot(full_signal)

并且看到信号看上去总是处于不同的顺序,这意味着它们在每个时期之后都被重新洗牌了。

如果有人有更详细的答案,他们可以在其中解释如何与DatasetAPI编译一起使用,或者他们可以弄清楚这些转换的顺序是否会减慢管道的速度,我将不胜感激!

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM