What does train_data.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat() do?
I am following the timeseries/LSTM tutorial for Tensorflow and struggle to understand what this line does, as it is not really explained:
train_data.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
I tried to look up what the different modules do, but I fail to understand the complete command and its effect on the dataset. Here is the entire tutorial: Click
It's an input pipeline definition based on the tensorflow.data API. Breaking it down:
(train_data                # some tf.data.Dataset, likely in the form of tuples (x, y)
 .cache()                  # cache the dataset in memory (avoids reapplying the preprocessing transformations to the input)
 .shuffle(BUFFER_SIZE)     # shuffle the samples so they are always fed to the network in a random order
 .batch(BATCH_SIZE)        # group samples into batches of size BATCH_SIZE (except the last one, which may be smaller)
 .repeat())                # repeat forever: the dataset will keep producing batches and never run out of data
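For intuition, the buffered shuffle and batching steps can be approximated with plain Python generators. This is only a sketch of the semantics, not TensorFlow's actual implementation; the helper names shuffle_buffered and batch_iter are made up for illustration:

```python
import random

def shuffle_buffered(source, buffer_size, rng=random.Random(0)):
    # Approximates tf.data's shuffle(): keep up to buffer_size items
    # in a buffer, emit one chosen at random, refill from the source.
    buffer = []
    for item in source:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            yield buffer.pop(rng.randrange(len(buffer)))
    while buffer:  # drain the remaining items at the end
        yield buffer.pop(rng.randrange(len(buffer)))

def batch_iter(source, batch_size):
    # Approximates batch(): group consecutive items into lists of
    # batch_size; the last batch may be smaller.
    chunk = []
    for item in source:
        chunk.append(item)
        if len(chunk) == batch_size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk

samples = list(range(10))
batches = list(batch_iter(shuffle_buffered(samples, buffer_size=4), batch_size=3))
# Every sample appears exactly once, in shuffled order,
# grouped into batches of 3 (the last batch holds the single leftover).
print(batches)
```

Note that with a small buffer_size the "shuffle" is only approximately random: an item can never move more than buffer_size positions earlier than where it started, which is why the tutorials suggest a buffer at least as large as the dataset when full shuffling matters.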
Notes:

Thanks to cache(), the second iteration over the dataset will load data from the in-memory cache instead of re-running the previous steps of the pipeline. This saves you some time if the data preprocessing is complex (but, for big datasets, it may be very heavy on your memory).

BUFFER_SIZE is the number of items in the shuffle buffer. The function fills the buffer and then randomly samples from it.

Pay attention: this is a pipeline definition, so you're specifying which operations are in the pipeline, not actually running them! The operations actually happen when you call next(iter(dataset)), not before.
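The laziness described above can be illustrated with a plain Python generator, which has the same "define now, run on iteration" behaviour (a minimal sketch, not tf.data itself):

```python
calls = []

def preprocess(x):
    # Stand-in for an expensive preprocessing/map step.
    calls.append(x)          # record that work actually happened
    return x * 2

# Building the pipeline is just a definition: nothing runs here.
pipeline = (preprocess(x) for x in range(3))
assert calls == []           # no element has been processed yet

# Pulling an element is what actually triggers the work,
# just like next(iter(dataset)) with tf.data.
first = next(iter(pipeline))
assert first == 0
assert calls == [0]          # only the first element was processed
print("pipeline ran lazily, on demand")
```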