![](/img/trans.png)
[英]Tensorflow dataset from generator OutOfRangeError: End of sequence
[英]OutOfRangeError: tensorflow iterator not reinitializing between runs
我正在使用以下设置通过tensorflow对Inception模型进行微调,并正在批量添加tf.Dataset
API。 但是,每次尝试训练该模型时(成功检索任何批次之前),我都会收到OutOfRangeError声明迭代器已用尽:
Caught OutOfRangeError. Stopping Training. End of sequence
[[node IteratorGetNext (defined at <ipython-input-8-c768436e70d8>:13) = IteratorGetNext[output_shapes=[[?,224,224,3], [?,1]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]
with tf.Graph().as_default():
我创建了一个函数,作为get_batch
的结果来馈入硬编码的批处理,并且该函数可以运行并收敛而没有任何问题,使我相信图形和会话代码可以正常工作。 我还测试了get_batch
函数以在会话中进行迭代,这不会导致任何错误。 我期望的行为是重新开始训练(尤其是重置笔记本电脑等)会在数据集上产生新的迭代器。
训练模型的代码:
with tf.Graph().as_default():
tf.logging.set_verbosity(tf.logging.INFO)
images, labels = get_batch(filenames=tf_train_record_path+train_file)
# Create the model, use the default arg scope to configure the batch norm parameters.
with slim.arg_scope(inception.inception_v1_arg_scope()):
logits, ax = inception.inception_v1(images, num_classes=1, is_training=True)
# Specify the loss function:
tf.losses.mean_squared_error(labels,logits)
total_loss = tf.losses.get_total_loss()
tf.summary.scalar('losses/Total_Loss', total_loss)
# Specify the optimizer and create the train op:
optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
train_op = slim.learning.create_train_op(total_loss, optimizer)
# Run the training:
final_loss = slim.learning.train(
train_op,
logdir=train_dir,
init_fn=get_init_fn(),
number_of_steps=1)
使用数据集获取批处理的代码
def get_batch(filenames):
dataset = tf.data.TFRecordDataset(filenames=filenames)
dataset = dataset.map(parse)
dataset = dataset.batch(2)
iterator = dataset.make_one_shot_iterator()
data_X, data_y = iterator.get_next()
return data_X, data_y
这个先前询问的问题类似于我遇到的问题,但是,我没有使用batch_join
调用。 如果这不是slim.learning.train,从检查点还原或作用域的问题,我不是。 任何帮助,将不胜感激!
您的输入管道看起来还可以。 问题可能出在TFRecords文件损坏上。 您可以尝试使用随机数据编写代码,或者通过tf.data.Dataset.from_tensor_slices()
将图像用作numpy数组。 另外,您的解析函数可能会导致问题。 尝试使用sess.run
打印图像/标签。
我建议使用Estimator API作为train_op。 它更加方便,并且苗条的将很快被弃用。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.