簡體   English   中英

TensorFlow 數據集 Shuffle 每個 Epoch

[英]TensorFlow Dataset Shuffle Each Epoch

在 Tensorflow 中的數據集 class 的手冊中,它顯示了如何對數據進行混洗以及如何對其進行批處理。 但是,如何在每個 epoch中對數據進行洗牌還不清楚。 我已經嘗試過以下方法,但是第二個時期的數據順序與第一個時期完全相同。 有人知道如何使用數據集在不同時期之間進行洗牌嗎?

n_epochs = 2
batch_size = 3

data = tf.contrib.data.Dataset.range(12)

data = data.repeat(n_epochs)
data = data.batch(batch_size)
next_batch = data.make_one_shot_iterator().get_next()

sess = tf.Session()
for _ in range(4):
    print(sess.run(next_batch))

print("new epoch")
data = data.shuffle(12)
for _ in range(4):
    print(sess.run(next_batch))

我的環境:Python 3.6,TensorFlow 1.4。

TensorFlow已將Dataset添加到tf.data

您應該對data.shuffle的位置保持謹慎。 在您的代碼中,數據的時期已經在您的shuffle之前被放入dataset的緩沖區中。 這是兩個可用於混洗數據集的示例。

洗牌所有元素

# shuffle all elements
import tensorflow as tf

n_epochs = 2
batch_size = 3
buffer_size = 5

dataset = tf.data.Dataset.range(12)
dataset = dataset.shuffle(buffer_size=buffer_size)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(n_epochs)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()

sess = tf.Session()
print("epoch 1")
for _ in range(4):
    print(sess.run(next_batch))
print("epoch 2")
for _ in range(4):
    print(sess.run(next_batch))

OUTPUT:

epoch 1
[1 4 5]
[3 0 7]
[6 9 8]
[10  2 11]
epoch 2
[2 0 6]
[1 7 4]
[5 3 8]
[11  9 10]

批次之間的混洗,而不是批量洗牌

# shuffle between batches, not shuffle in a batch
import tensorflow as tf

n_epochs = 2
batch_size = 3
buffer_size = 5

dataset = tf.data.Dataset.range(12)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(n_epochs)
dataset = dataset.shuffle(buffer_size=buffer_size)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()

sess = tf.Session()
print("epoch 1")
for _ in range(4):
    print(sess.run(next_batch))
print("epoch 2")
for _ in range(4):
    print(sess.run(next_batch))

OUTPUT:

epoch 1
[0 1 2]
[6 7 8]
[3 4 5]
[6 7 8]
epoch 2
[3 4 5]
[0 1 2]
[ 9 10 11]
[ 9 10 11]

在我看來,你在兩種情況下使用相同的next_batch 所以,depedening你真正想要的,你可能需要重新創建next_batch你的第二個呼叫之前sess.run ,如下圖所示,否則data = data.shuffle(12)沒有對任何影響next_batch你在前面創建編碼。

n_epochs = 2
batch_size = 3

data = tf.contrib.data.Dataset.range(12)

data = data.repeat(n_epochs)
data = data.batch(batch_size)
next_batch = data.make_one_shot_iterator().get_next()

sess = tf.Session()
for _ in range(4):
    print(sess.run(next_batch))

print("new epoch")
data = data.shuffle(12)

"""See how I recreate next_batch after the data has been shuffled"""
next_batch = data.make_one_shot_iterator().get_next()
for _ in range(4):
    print(sess.run(next_batch))

請讓我知道這可不可以幫你。 謝謝。

這是一個不需要調用repeat的更簡單的解決方案:

dataset = tf.data.Dataset.range(12)
dataset = dataset.shuffle(buffer_size=dataset.cardinality(), reshuffle_each_iteration=True)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM