簡體   English   中英

使用TensorFlow Dataset API的Epoch計數器

[英]Epoch counter with TensorFlow Dataset API

我正在將TensorFlow代碼從舊隊列接口更改為新的Dataset API 在我的舊代碼中,每次在隊列中訪問和處理新的輸入張量時, tf.Variable都會通過遞增tf.Variable跟蹤紀元數。 我想用新的Dataset API來計算這個時代,但是我在使用它時遇到了一些麻煩。

由於我在預處理階段生成了可變數量的數據項,因此在訓練循環中遞增(Python)計數器並不是一件簡單的事情 - 我需要根據輸入來計算epoch計數。隊列或數據集。

我使用舊的隊列系統模仿了之前的情況,這就是我最終得到的數據集API(簡化示例):

with tf.Graph().as_default():

    data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
    input_tensors = (data,)

    epoch_counter = tf.Variable(initial_value=0.0, dtype=tf.float32,
                                trainable=False)

    def pre_processing_func(data_):
        data_size = tf.constant(0.1, dtype=tf.float32)
        epoch_counter_op = tf.assign_add(epoch_counter, data_size)
        with tf.control_dependencies([epoch_counter_op]):
            # normally I would do data-augmentation here
            results = (tf.expand_dims(data_, axis=0),)
            return tf.data.Dataset.from_tensor_slices(results)

    dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
    dataset = dataset_source.flat_map(pre_processing_func)
    dataset = dataset.repeat()
    # ... do something with 'dataset' and print
    # the value of 'epoch_counter' every once a while

但是,這不起作用。 它崩潰了一個神秘的錯誤信息:

 TypeError: In op 'AssignAdd', input types ([tf.float32, tf.float32])
 are not compatible with expected types ([tf.float32_ref, tf.float32])

仔細檢查表明epoch_counter變量可能根本無法在pre_processing_func中訪問。 它可能生活在不同的圖表中嗎?

知道如何修復上面的例子嗎? 或者如何通過其他方式獲得紀元計數器(帶小數點,例如0.4或2.9)?

TL; DR :用以下內容替換epoch_counter的定義:

epoch_counter = tf.get_variable("epoch_counter", initializer=0.0,
                                trainable=False, use_resource=True)

tf.data.Dataset轉換中使用TensorFlow變量有一些限制。 原則限制是所有變量必須是“資源變量”而不是舊的“參考變量”; 遺憾的是,由於向后兼容性原因, tf.Variable仍會創建“參考變量”。

一般來說,如果可以避免變量,我建議不要在tf.data管道中使用變量。 例如,您可以使用Dataset.range()來定義紀元計數器,然后執行以下操作:

epoch_counter = tf.data.Dataset.range(NUM_EPOCHS)
dataset = epoch_counter.flat_map(lambda i: tf.data.Dataset.zip(
    (pre_processing_func(data), tf.data.Dataset.from_tensors(i).repeat()))

上面的片段將每個值附加一個紀元計數器作為第二個組件。

要添加到@ mrry的最佳答案,如果您想要保留在tf.data管道中並且還希望跟蹤每個時期內的迭代,您可以嘗試下面的解決方案。 如果你有非單位批量大小,我想你必須添加行data = data.batch(bs)

import tensorflow as tf
import itertools

def step_counter(): 
    for i in itertools.count(): yield i

num_examples = 3
num_epochs = 2
num_iters = num_examples * num_epochs

features = tf.data.Dataset.range(num_examples)
labels = tf.data.Dataset.range(num_examples)
data = tf.data.Dataset.zip((features, labels))
data = data.shuffle(num_examples)

step = tf.data.Dataset.from_generator(step_counter, tf.int32)
data = tf.data.Dataset.zip((data, step))

epoch = tf.data.Dataset.range(num_epochs)
data = epoch.flat_map(
    lambda i: tf.data.Dataset.zip(
        (data, tf.data.Dataset.from_tensors(i).repeat())))

data = data.repeat(num_epochs)
it = data.make_one_shot_iterator()
example = it.get_next()

with tf.Session() as sess:
    for _ in range(num_iters):
        ((x, y), st), ep = sess.run(example)
        print(f'step {st} \t epoch {ep} \t x {x} \t y {y}')

打印:

step 0   epoch 0     x 2     y 2
step 1   epoch 0     x 0     y 0
step 2   epoch 0     x 1     y 1
step 0   epoch 1     x 2     y 2
step 1   epoch 1     x 0     y 0
step 2   epoch 1     x 1     y 1

data = data.repeat(num_epochs)導致重復已經為num_epochs重復的數據集(也是epoch計數器)。 可以通過for _ in range(num_iters):替換for _ in range(num_iters):輕松獲得for _ in range(num_iters): for _ in range(num_iters+1):

我將numerica的示例代碼擴展到批處理並替換了itertool部分:

num_examples = 5
num_epochs = 4
batch_size = 2
num_iters = int(num_examples * num_epochs / batch_size)

features = tf.data.Dataset.range(num_examples)
labels = tf.data.Dataset.range(num_examples)

data = tf.data.Dataset.zip((features, labels))
data = data.shuffle(num_examples)

epoch = tf.data.Dataset.range(num_epochs)
data = epoch.flat_map(
    lambda i: tf.data.Dataset.zip((
        data,
        tf.data.Dataset.from_tensors(i).repeat(),
        tf.data.Dataset.range(num_examples)
    ))
)

# to flatten the nested datasets
data = data.map(lambda samples, *cnts: samples+cnts )
data = data.batch(batch_size=batch_size)

it = data.make_one_shot_iterator()
x, y, ep, st = it.get_next()

with tf.Session() as sess:
    for _ in range(num_iters):
        x_, y_, ep_, st_ = sess.run([x, y, ep, st])
        print(f'step {st_}\t epoch {ep_} \t x {x_} \t y {y_}')

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM