繁体   English   中英

如何在每个纪元后正确洗牌 Tensorflow 中的数据集

[英]How to properly shuffle a dataset in Tensorflow after every epoch

我目前正在研究一个带有 Tensorflow 和 Keras 的神经网络,我有一个写在 TFRecord 上的数据集,我必须从中读取数据,问题是神经网络经过了体积训练,我没有足够的 ZCD69B4957F06CD8198D7 来存储全部在 ram 中,我正在读取这样的数据,代码取自这两个地方:

https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_solution.ipynb

https://keras.io/examples/keras_recipes/tfrecord/

def load_dataset(filenames):
  option_no_order = tf.data.Options()
  option_no_order.experimental_deterministic = False

  dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
  dataset = dataset.with_options(option_no_order)
  dataset = dataset.map(decode_record, num_parallel_calls=AUTO)
  return dataset

def get_batched_dataset(filenames, train=False):
dataset = load_dataset(filenames)
if train:
  dataset = dataset.repeat() # Best practices for Keras: Training dataset: repeat then batch Evaluation dataset: do not repeat
dataset = dataset.cache() # This dataset fits in RAM
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
return dataset

这段代码有效,但我想在每个 epoch 之后添加训练数据集的改组,我写了这个:

def get_batched_dataset(filenames, train=False):
dataset = load_dataset(filenames)
if train:
  dataset = dataset.shuffle(200, reshuffle_each_iteration=True) ###############
  dataset = dataset.repeat() # Best practices for Keras: Training dataset: repeat then batch Evaluation dataset: do not repeat
dataset = dataset.cache() # This dataset fits in RAM
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
return dataset

这段代码正在工作,但我观察到每个时期之后 RAM 使用量的增加,在 4 个时期之后,整个 session 因“没有足够的 RAM”而崩溃。

我得到数据集并像这样训练网络:

train_dataset = get_directories()
training_dataset = get_batched_dataset('train.tfrecords', train=True)
validation_dataset = get_batched_dataset('valid.tfrecords', train=False)
model.fit(training_dataset, steps_per_epoch=len(train_dataset), epochs=80, validation_data=validation_dataset, callbacks=my_callbacks)

随机播放 function 需要 200 个卷并将它们放入缓冲区并随机馈送到网络,我不明白为什么 session 会那样崩溃

您可以查看Custom DataGenerators

在他们的on_epoch_end方法中,您可以对每个时期从 memory 加载的数据进行洗牌。

您正在使用 dataset.shuffle(),然后在使用.cache()。 由于您每次都在更改数据顺序,因此 tensorflow 将缓存 memory 中的每个混洗数据集。 这导致同一数据集的多个洗牌副本和 RAM 被填满。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM