[英]How to properly shuffle a dataset in Tensorflow after every epoch
我目前正在研究一个带有 Tensorflow 和 Keras 的神经网络,我有一个写在 TFRecord 上的数据集,我必须从中读取数据,问题是神经网络经过了体积训练,我没有足够的 ZCD69B4957F06CD8198D7 来存储全部在 ram 中,我正在读取这样的数据,代码取自这两个地方:
https://keras.io/examples/keras_recipes/tfrecord/
def load_dataset(filenames):
option_no_order = tf.data.Options()
option_no_order.experimental_deterministic = False
dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
dataset = dataset.with_options(option_no_order)
dataset = dataset.map(decode_record, num_parallel_calls=AUTO)
return dataset
def get_batched_dataset(filenames, train=False):
dataset = load_dataset(filenames)
if train:
dataset = dataset.repeat() # Best practices for Keras: Training dataset: repeat then batch Evaluation dataset: do not repeat
dataset = dataset.cache() # This dataset fits in RAM
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
return dataset
这段代码有效,但我想在每个 epoch 之后添加训练数据集的改组,我写了这个:
def get_batched_dataset(filenames, train=False):
dataset = load_dataset(filenames)
if train:
dataset = dataset.shuffle(200, reshuffle_each_iteration=True) ###############
dataset = dataset.repeat() # Best practices for Keras: Training dataset: repeat then batch Evaluation dataset: do not repeat
dataset = dataset.cache() # This dataset fits in RAM
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
return dataset
这段代码正在工作,但我观察到每个时期之后 RAM 使用量的增加,在 4 个时期之后,整个 session 因“没有足够的 RAM”而崩溃。
我得到数据集并像这样训练网络:
train_dataset = get_directories()
training_dataset = get_batched_dataset('train.tfrecords', train=True)
validation_dataset = get_batched_dataset('valid.tfrecords', train=False)
model.fit(training_dataset, steps_per_epoch=len(train_dataset), epochs=80, validation_data=validation_dataset, callbacks=my_callbacks)
随机播放 function 需要 200 个卷并将它们放入缓冲区并随机馈送到网络,我不明白为什么 session 会那样崩溃
您可以查看Custom DataGenerators 。
在他们的on_epoch_end
方法中,您可以对每个时期从 memory 加载的数据进行洗牌。
您正在使用 dataset.shuffle(),然后在使用.cache()。 由于您每次都在更改数据顺序,因此 tensorflow 将缓存 memory 中的每个混洗数据集。 这导致同一数据集的多个洗牌副本和 RAM 被填满。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.