簡體   English   中英

OutOfRangeError:運行之間的Tensorflow迭代器未重新初始化

[英]OutOfRangeError: tensorflow iterator not reinitializing between runs

我正在使用以下設置通過tensorflow對Inception模型進行微調,並正在批量添加tf.Dataset API。 但是,每次嘗試訓練該模型時(成功檢索任何批次之前),我都會收到OutOfRangeError聲明迭代器已用盡:

Caught OutOfRangeError. Stopping Training. End of sequence
     [[node IteratorGetNext (defined at <ipython-input-8-c768436e70d8>:13)  = IteratorGetNext[output_shapes=[[?,224,224,3], [?,1]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]
with tf.Graph().as_default():

我創建了一個函數,作為get_batch的結果來饋入硬編碼的批處理,並且該函數可以運行並收斂而沒有任何問題,使我相信圖形和會話代碼可以正常工作。 我還測試了get_batch函數以在會話中進行迭代,這不會導致任何錯誤。 我期望的行為是重新開始訓練(尤其是重置筆記本電腦等)會在數據集上產生新的迭代器。

訓練模型的代碼:

with tf.Graph().as_default():

    tf.logging.set_verbosity(tf.logging.INFO)
    images, labels = get_batch(filenames=tf_train_record_path+train_file)
    # Create the model, use the default arg scope to configure the batch norm parameters.
    with slim.arg_scope(inception.inception_v1_arg_scope()):
        logits, ax = inception.inception_v1(images, num_classes=1, is_training=True)

    # Specify the loss function:
    tf.losses.mean_squared_error(labels,logits)
    total_loss = tf.losses.get_total_loss()
    tf.summary.scalar('losses/Total_Loss', total_loss)


     # Specify the optimizer and create the train op:
    optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
    train_op = slim.learning.create_train_op(total_loss, optimizer)

    # Run the training:
    final_loss = slim.learning.train(
        train_op,
        logdir=train_dir,
        init_fn=get_init_fn(),
        number_of_steps=1)

使用數據集獲取批處理的代碼

def get_batch(filenames):
    dataset = tf.data.TFRecordDataset(filenames=filenames)

    dataset = dataset.map(parse)
    dataset = dataset.batch(2)

    iterator = dataset.make_one_shot_iterator()
    data_X, data_y = iterator.get_next()

    return data_X, data_y 

這個先前詢問的問題類似於我遇到的問題,但是,我沒有使用batch_join調用。 如果這不是slim.learning.train,從檢查點還原或作用域的問題,我不是。 任何幫助,將不勝感激!

您的輸入管道看起來還可以。 問題可能出在TFRecords文件損壞上。 您可以嘗試使用隨機數據編寫代碼,或者通過tf.data.Dataset.from_tensor_slices()將圖像用作numpy數組。 另外,您的解析函數可能會導致問題。 嘗試使用sess.run打印圖像/標簽。

我建議使用Estimator API作為train_op。 它更加方便,並且苗條的將很快被棄用。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM