
TensorFlow Dataset Shuffle Each Epoch

The TensorFlow manual for the Dataset class shows how to shuffle the data and how to batch it. However, it's not apparent how to shuffle the data each epoch. I've tried the code below, but the data comes out in exactly the same order in the second epoch as in the first. Does anybody know how to shuffle between epochs using a Dataset?

n_epochs = 2
batch_size = 3

data = tf.contrib.data.Dataset.range(12)

data = data.repeat(n_epochs)
data = data.batch(batch_size)
next_batch = data.make_one_shot_iterator().get_next()

sess = tf.Session()
for _ in range(4):
    print(sess.run(next_batch))

print("new epoch")
data = data.shuffle(12)
for _ in range(4):
    print(sess.run(next_batch))

My environment: Python 3.6, TensorFlow 1.4.

TensorFlow has added Dataset to tf.data.

You should be careful about where you place data.shuffle. In your code, the epochs of data have already been put into the dataset's buffer before your shuffle. Here are two usable examples of shuffling a dataset.

Shuffle all elements

# shuffle all elements
import tensorflow as tf

n_epochs = 2
batch_size = 3
buffer_size = 5

dataset = tf.data.Dataset.range(12)
dataset = dataset.shuffle(buffer_size=buffer_size)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(n_epochs)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()

sess = tf.Session()
print("epoch 1")
for _ in range(4):
    print(sess.run(next_batch))
print("epoch 2")
for _ in range(4):
    print(sess.run(next_batch))

OUTPUT:

epoch 1
[1 4 5]
[3 0 7]
[6 9 8]
[10  2 11]
epoch 2
[2 0 6]
[1 7 4]
[5 3 8]
[11  9 10]
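Note that buffer_size = 5 is smaller than the 12 elements in the dataset, so the shuffle above is only local: each output is drawn from a sliding window of 5 pending elements. A rough pure-Python model of such a shuffle buffer may make this clearer. It is only a sketch of the documented behaviour, not TensorFlow's implementation, and buffered_shuffle is a hypothetical helper name.

```python
import random


def buffered_shuffle(items, buffer_size, rng):
    """Pure-Python model of a fixed-size shuffle buffer (not TensorFlow code)."""
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            # Fill the buffer before emitting anything.
            buffer.append(item)
            continue
        # Emit a random buffered element and put the new item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: emit whatever is left, in random order.
    rng.shuffle(buffer)
    yield from buffer


rng = random.Random(0)  # seeded only to make the sketch reproducible
shuffled = list(buffered_shuffle(range(12), buffer_size=5, rng=rng))
print(shuffled)
```

In this model the k-th output (counting from 0) can never be larger than buffer_size + k - 1, so the first value is always one of 0 to 4. For a uniform shuffle over the whole dataset, the buffer has to be at least as large as the dataset (12 here).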

Shuffle between batches, not within a batch

# shuffle between batches, not shuffle in a batch
import tensorflow as tf

n_epochs = 2
batch_size = 3
buffer_size = 5

dataset = tf.data.Dataset.range(12)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(n_epochs)
dataset = dataset.shuffle(buffer_size=buffer_size)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()

sess = tf.Session()
print("epoch 1")
for _ in range(4):
    print(sess.run(next_batch))
print("epoch 2")
for _ in range(4):
    print(sess.run(next_batch))

OUTPUT:

epoch 1
[0 1 2]
[6 7 8]
[3 4 5]
[6 7 8]
epoch 2
[3 4 5]
[0 1 2]
[ 9 10 11]
[ 9 10 11]
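In the output above, [6 7 8] appears twice under "epoch 1" and [ 9 10 11] twice under "epoch 2". That is a consequence of calling shuffle after repeat: the shuffle buffer sees one long stream of batches and can mix batches from neighbouring epochs. The following pure-Python sketch compares the two orderings. It only models the pipeline semantics and is not TensorFlow code; buffered_shuffle and batch are hypothetical helper names.

```python
import random


def buffered_shuffle(items, buffer_size, rng):
    """Pure-Python model of a fixed-size shuffle buffer (not TensorFlow code)."""
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    rng.shuffle(buffer)
    yield from buffer


def batch(items, batch_size):
    """Group consecutive items into lists of length batch_size."""
    chunk = []
    for item in items:
        chunk.append(item)
        if len(chunk) == batch_size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk


rng = random.Random(0)  # seeded only to make the sketch reproducible
n_epochs, batch_size, buffer_size = 2, 3, 5

# Order 1: shuffle -> batch -> repeat.
# Each epoch is shuffled on its own, so it contains every element exactly once.
order1 = [
    b
    for _ in range(n_epochs)
    for b in batch(buffered_shuffle(range(12), buffer_size, rng), batch_size)
]

# Order 2: batch -> repeat -> shuffle.
# Batches stay intact, but the buffer spans the epoch boundary.
repeated = [b for _ in range(n_epochs) for b in batch(range(12), batch_size)]
order2 = list(buffered_shuffle(repeated, buffer_size, rng))

print("shuffle -> batch -> repeat:", order1)
print("batch -> repeat -> shuffle:", order2)
```

With the first ordering, every group of four batches covers 0 to 11 exactly once. With the second, each batch still appears twice overall, but nothing guarantees that the first four batches are all different.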

It appears to me that you are using the same next_batch in both cases. So, depending on what you really want, you may need to recreate next_batch before your second call to sess.run, as shown below; otherwise data = data.shuffle(12) has no effect on the next_batch you created earlier in the code.

n_epochs = 2
batch_size = 3

data = tf.contrib.data.Dataset.range(12)

data = data.repeat(n_epochs)
data = data.batch(batch_size)
next_batch = data.make_one_shot_iterator().get_next()

sess = tf.Session()
for _ in range(4):
    print(sess.run(next_batch))

print("new epoch")
data = data.shuffle(12)

"""See how I recreate next_batch after the data has been shuffled"""
next_batch = data.make_one_shot_iterator().get_next()
for _ in range(4):
    print(sess.run(next_batch))

Please let me know if this helps. Thanks.

Here is a simpler solution that does not need to call repeat:

import tensorflow as tf

dataset = tf.data.Dataset.range(12)
dataset = dataset.shuffle(buffer_size=dataset.cardinality(), reshuffle_each_iteration=True)
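The idea here is that every fresh pass over the dataset draws a new order, so an outer Python loop over epochs can take the place of repeat. A toy pure-Python stand-in for that behaviour is sketched below. ReshufflingDataset is a hypothetical class used only for illustration, not a TensorFlow API, and it assumes the whole dataset fits in the shuffle buffer, as it does with buffer_size=dataset.cardinality().

```python
import random


class ReshufflingDataset:
    """Toy stand-in for a dataset that reshuffles on every iteration.

    Hypothetical class for illustration only; not a TensorFlow API.
    """

    def __init__(self, items, seed=None):
        self._items = list(items)
        self._rng = random.Random(seed)

    def __iter__(self):
        # Each new pass gets its own full shuffle of all elements.
        order = self._items[:]
        self._rng.shuffle(order)
        return iter(order)


dataset = ReshufflingDataset(range(12), seed=0)
n_epochs = 2

# One pass over the dataset per epoch; no repeat() equivalent is needed.
epochs = [list(dataset) for _ in range(n_epochs)]
for number, epoch in enumerate(epochs, start=1):
    print("epoch", number, epoch)
```

Each epoch in this sketch is a full permutation of the 12 elements, drawn independently of the previous one.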


 