TensorFlow freezes on sess.run call when using queues

I am testing the TensorFlow queue system with a simple program that uses queues to process its input. The code of the sample program is:

import tensorflow as tf
import numpy as np

def model(example_batch):
    dense1 = tf.layers.dense(inputs=example_batch, units=64, activation=tf.nn.relu)
    dense2 = tf.layers.dense(inputs=dense1, units=2)
    return dense2

x_input_data = tf.random_normal([1024, 16], mean=0, stddev=1)
q = tf.FIFOQueue(capacity=1, dtypes=tf.float32, shapes=[1024, 16])

enqueue_op = q.enqueue(x_input_data)
numberOfThreads = 1

qr = tf.train.QueueRunner(q, [enqueue_op] * numberOfThreads)
tf.train.add_queue_runner(qr)
input = q.dequeue()

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
example_batch = tf.train.batch([input], batch_size=1, num_threads=numberOfThreads, capacity=1, enqueue_many=False)

for step in range(100):
    print ("step")
    #sess.run(input)
    p = sess.run(model(example_batch))

coord.request_stop()
coord.join(threads)
sess.close()

The problem is that it freezes at the first p = sess.run(model(example_batch)) call and stays there indefinitely.

What is wrong with the sample program?

The problem is the order of these lines:

threads = tf.train.start_queue_runners(sess=sess, coord=coord)
example_batch = tf.train.batch([input], batch_size=1, num_threads=numberOfThreads, capacity=1, enqueue_many=False)

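The first line calls tf.train.start_queue_runners(), which starts background threads only for the queue runners that have been registered so far. The second line calls tf.train.batch(), which adds a new queue together with its own queue runner. Because that runner is registered after tf.train.start_queue_runners() has already been called, its thread is never started, the batch queue is never filled, and the dequeue inside sess.run() blocks forever.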

The solution is quite simple: reverse the two lines, so that tf.train.start_queue_runners() is called after tf.train.batch().

example_batch = tf.train.batch([input], batch_size=1, num_threads=numberOfThreads, capacity=1, enqueue_many=False)
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
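
To see why the order matters, here is a rough analogy written with Python's standard queue and threading modules instead of TensorFlow. It is only a sketch of the idea: run_pipeline, add_runner, start_runners and fill are hypothetical names made up for this illustration, not TensorFlow APIs, and a timeout stands in for the indefinite hang so the script can finish.

```python
import queue
import threading


def run_pipeline(start_runners_first, timeout=0.5):
    """Return the dequeued batch, or None if the dequeue timed out."""
    runners = []                 # stand-in for the collection of queue runners
    stop = threading.Event()

    def add_runner(target_q, produce):
        # Registering a runner does not start any thread yet.
        runners.append((target_q, produce))

    def fill(target_q, produce):
        # Body of one background thread: keep the target queue filled.
        while not stop.is_set():
            try:
                item = produce()
            except queue.Empty:
                continue
            while not stop.is_set():
                try:
                    target_q.put(item, timeout=0.05)
                    break
                except queue.Full:
                    pass

    def start_runners():
        # Only runners registered before this call get a thread.
        threads = []
        for target_q, produce in list(runners):
            t = threading.Thread(target=fill, args=(target_q, produce), daemon=True)
            t.start()
            threads.append(t)
        return threads

    # First queue, playing the role of q / enqueue_op in the question.
    source_q = queue.Queue(maxsize=1)
    add_runner(source_q, lambda: "example")

    def make_batch_queue():
        # Playing the role of tf.train.batch(): a new queue plus its own runner.
        batch_q = queue.Queue(maxsize=1)
        add_runner(batch_q, lambda: [source_q.get(timeout=0.05)])
        return batch_q

    if start_runners_first:
        threads = start_runners()      # batch runner does not exist yet
        batch_q = make_batch_queue()
    else:
        batch_q = make_batch_queue()
        threads = start_runners()      # both runners get a thread

    try:
        return batch_q.get(timeout=timeout)
    except queue.Empty:
        return None                    # the real program would hang here
    finally:
        stop.set()
        for t in threads:
            t.join()


print(run_pipeline(start_runners_first=True))
print(run_pipeline(start_runners_first=False))
```

With the runners started first, nothing ever fills the second queue and the dequeue gives up after the timeout. With the order reversed, both queues have a thread and the dequeue returns a batch.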
