How to make an input pipeline using tf.RandomShuffleQueue and tf.train.shuffle_batch in TensorFlow?
While studying deep learning I ran into a problem: my model would not run because of machine limitations (CPU, RAM, etc.).
Model: an 8-layer model
Data: a numpy array of shape (20000, 20, 20, 3)
So I tried to feed the model through a queue, but it failed.
This is what I tried to follow:
https://www.tensorflow.org/images/AnimatedFileQueues.gif
The difference is that I don't use filenames; I use a numpy array.
Please let me know what the problem is and how to fix it.
If you have any other references, please share the URLs.
import tensorflow as tf
import numpy as np

N_SAMPLES = 1000
NUM_THREADS = 3
all_data = 10 * np.random.randn(N_SAMPLES, 4) + 1
all_target = np.random.randint(0, 2, size=N_SAMPLES)

queue = tf.RandomShuffleQueue(capacity=200,
                              min_after_dequeue=51,
                              dtypes=[tf.float32, tf.int32],
                              shapes=[[4], []],
                              names=None,
                              seed=20,
                              shared_name=None,
                              name='random_shuffle_queue')
enqueue_op = queue.enqueue_many([all_data, all_target])
data_sample, label_sample = queue.dequeue()
qr = tf.train.QueueRunner(queue, [enqueue_op] * NUM_THREADS)

train_data_batch, train_labels_batch = tf.train.shuffle_batch(
    [data_sample, label_sample],
    capacity=200,
    min_after_dequeue=51,
    enqueue_many=False,
    batch_size=50,
    seed=20)

with tf.Session() as sess:
    # Create a coordinator and launch the queue runner threads.
    coord = tf.train.Coordinator()
    enqueue_threads = qr.create_threads(sess, coord=coord, start=True)
    for step in range(10):  # do 10 iterations
        if coord.should_stop():
            break
        one_data, one_label = sess.run([train_data_batch, train_labels_batch])
        print(one_label)  # I don't know why this never executes
    coord.request_stop()
    coord.join(enqueue_threads)
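Conceptually, the RandomShuffleQueue in the code above is a shuffle buffer: it fills up to `capacity` elements, and every dequeue picks a random buffered element while at least `min_after_dequeue` elements remain to keep the output well mixed. A minimal pure-Python sketch of that behavior (the function and its parameters mirror the queue's arguments but are otherwise illustrative, not TensorFlow API):

```python
import random

def shuffle_queue_iter(data, capacity=200, min_after_dequeue=51, seed=20):
    """Emulate a RandomShuffleQueue as a single-threaded generator."""
    rng = random.Random(seed)
    it = iter(data)
    buf = []
    # Fill the buffer up to `capacity` before the first dequeue.
    for item in it:
        buf.append(item)
        if len(buf) == capacity:
            break
    while buf:
        # Dequeue a random buffered element...
        yield buf.pop(rng.randrange(len(buf)))
        # ...then refill from the input so the buffer stays above
        # `min_after_dequeue` while data remains.
        for item in it:
            buf.append(item)
            if len(buf) >= min_after_dequeue:
                break

samples = list(range(1000))
shuffled = list(shuffle_queue_iter(samples, capacity=200, min_after_dequeue=51))
```

Every input element comes out exactly once, in a randomized order whose mixing quality grows with `min_after_dequeue`.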
In this case, you can use tf.train.slice_input_producer:
data_sample, label_sample = tf.train.slice_input_producer(
    [all_data, all_target], num_epochs=None,
    shuffle=True, seed=None, capacity=1000, shared_name=None, name=None)
train_data_batch, train_labels_batch = tf.train.shuffle_batch(
    [data_sample, label_sample],
    capacity=200, min_after_dequeue=51, enqueue_many=False, batch_size=50, seed=20)
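This pipeline roughly boils down to: shuffle the sample indices, then slice aligned batches out of both arrays. A rough NumPy equivalent for one pass over the data (shapes and batch size taken from the snippets above; the streaming shuffle buffer used by shuffle_batch is approximated here by one full permutation per epoch):

```python
import numpy as np

N_SAMPLES = 1000
BATCH_SIZE = 50
all_data = 10 * np.random.randn(N_SAMPLES, 4) + 1
all_target = np.random.randint(0, 2, size=N_SAMPLES)

rng = np.random.default_rng(seed=20)
order = rng.permutation(N_SAMPLES)          # shuffle row indices once per epoch
for start in range(0, N_SAMPLES, BATCH_SIZE):
    idx = order[start:start + BATCH_SIZE]
    data_batch = all_data[idx]              # shape (50, 4), rows stay aligned
    label_batch = all_target[idx]           # shape (50,), with their labels
```

Indexing both arrays with the same shuffled `idx` is what keeps each sample paired with its label, which is also what slice_input_producer guarantees by slicing both tensors with one index.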