[英]Epoch counter with TensorFlow Dataset API
我正在将TensorFlow代码从旧队列接口更改为新的Dataset API 。 在我的旧代码中,每次在队列中访问和处理新的输入张量时, tf.Variable
都会通过递增tf.Variable
跟踪纪元数。 我想用新的Dataset API来计算这个时代,但是我在使用它时遇到了一些麻烦。
由于我在预处理阶段生成了可变数量的数据项,因此在训练循环中递增(Python)计数器并不是一件简单的事情 - 我需要根据输入来计算epoch计数。队列或数据集。
我使用旧的队列系统模仿了之前的情况,这就是我最终得到的数据集API(简化示例):
with tf.Graph().as_default():
data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
input_tensors = (data,)
epoch_counter = tf.Variable(initial_value=0.0, dtype=tf.float32,
trainable=False)
def pre_processing_func(data_):
data_size = tf.constant(0.1, dtype=tf.float32)
epoch_counter_op = tf.assign_add(epoch_counter, data_size)
with tf.control_dependencies([epoch_counter_op]):
# normally I would do data-augmentation here
results = (tf.expand_dims(data_, axis=0),)
return tf.data.Dataset.from_tensor_slices(results)
dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
dataset = dataset_source.flat_map(pre_processing_func)
dataset = dataset.repeat()
# ... do something with 'dataset' and print
# the value of 'epoch_counter' every once a while
但是,这不起作用。 它崩溃了一个神秘的错误信息:
TypeError: In op 'AssignAdd', input types ([tf.float32, tf.float32])
are not compatible with expected types ([tf.float32_ref, tf.float32])
仔细检查表明epoch_counter
变量可能根本无法在pre_processing_func
中访问。 它可能生活在不同的图表中吗?
知道如何修复上面的例子吗? 或者如何通过其他方式获得纪元计数器(带小数点,例如0.4或2.9)?
TL; DR :用以下内容替换epoch_counter
的定义:
epoch_counter = tf.get_variable("epoch_counter", initializer=0.0,
trainable=False, use_resource=True)
在tf.data.Dataset
转换中使用TensorFlow变量有一些限制。 原则限制是所有变量必须是“资源变量”而不是旧的“参考变量”; 遗憾的是,由于向后兼容性原因, tf.Variable
仍会创建“参考变量”。
一般来说,如果可以避免变量,我建议不要在tf.data
管道中使用变量。 例如,您可以使用Dataset.range()
来定义纪元计数器,然后执行以下操作:
epoch_counter = tf.data.Dataset.range(NUM_EPOCHS)
dataset = epoch_counter.flat_map(lambda i: tf.data.Dataset.zip(
(pre_processing_func(data), tf.data.Dataset.from_tensors(i).repeat()))
上面的片段将每个值附加一个纪元计数器作为第二个组件。
要添加到@ mrry的最佳答案,如果您想要保留在tf.data
管道中并且还希望跟踪每个时期内的迭代,您可以尝试下面的解决方案。 如果你有非单位批量大小,我想你必须添加行data = data.batch(bs)
。
import tensorflow as tf
import itertools
def step_counter():
for i in itertools.count(): yield i
num_examples = 3
num_epochs = 2
num_iters = num_examples * num_epochs
features = tf.data.Dataset.range(num_examples)
labels = tf.data.Dataset.range(num_examples)
data = tf.data.Dataset.zip((features, labels))
data = data.shuffle(num_examples)
step = tf.data.Dataset.from_generator(step_counter, tf.int32)
data = tf.data.Dataset.zip((data, step))
epoch = tf.data.Dataset.range(num_epochs)
data = epoch.flat_map(
lambda i: tf.data.Dataset.zip(
(data, tf.data.Dataset.from_tensors(i).repeat())))
data = data.repeat(num_epochs)
it = data.make_one_shot_iterator()
example = it.get_next()
with tf.Session() as sess:
for _ in range(num_iters):
((x, y), st), ep = sess.run(example)
print(f'step {st} \t epoch {ep} \t x {x} \t y {y}')
打印:
step 0 epoch 0 x 2 y 2
step 1 epoch 0 x 0 y 0
step 2 epoch 0 x 1 y 1
step 0 epoch 1 x 2 y 2
step 1 epoch 1 x 0 y 0
step 2 epoch 1 x 1 y 1
行data = data.repeat(num_epochs)
导致重复已经为num_epochs重复的数据集(也是epoch计数器)。 可以通过for _ in range(num_iters):
替换for _ in range(num_iters):
轻松获得for _ in range(num_iters):
for _ in range(num_iters+1):
我将numerica的示例代码扩展到批处理并替换了itertool
部分:
num_examples = 5
num_epochs = 4
batch_size = 2
num_iters = int(num_examples * num_epochs / batch_size)
features = tf.data.Dataset.range(num_examples)
labels = tf.data.Dataset.range(num_examples)
data = tf.data.Dataset.zip((features, labels))
data = data.shuffle(num_examples)
epoch = tf.data.Dataset.range(num_epochs)
data = epoch.flat_map(
lambda i: tf.data.Dataset.zip((
data,
tf.data.Dataset.from_tensors(i).repeat(),
tf.data.Dataset.range(num_examples)
))
)
# to flatten the nested datasets
data = data.map(lambda samples, *cnts: samples+cnts )
data = data.batch(batch_size=batch_size)
it = data.make_one_shot_iterator()
x, y, ep, st = it.get_next()
with tf.Session() as sess:
for _ in range(num_iters):
x_, y_, ep_, st_ = sess.run([x, y, ep, st])
print(f'step {st_}\t epoch {ep_} \t x {x_} \t y {y_}')
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.