簡體   English   中英

如何在tensorflow中使用tf.RandomShuffleQueue和tf.train.shuffle_batch制作輸入管道?

[英]How to make input pipeline using tf.RandomShuffleQueue and tf.train.shuffle_batch in tensorflow?

在研究深度學習時,存在一個問題,即由於計算機(cpu,ram等)的問題,該模型無法正常工作。

模型:8層模型
數據:numpy數組(20000、20、20、3)

因此,我嘗試將隊列應用於模型,但是失敗了。

這就是我要做的。
https://www.tensorflow.org/images/AnimatedFileQueues.gif

區別在於我不使用文件名。
我使用numpy數組。

請讓我知道問題是什么以及如何解決。

如果您還有其他參考,請提供網址。

import tensorflow as tf
import numpy as np

N_SAMPLES = 1000; NUM_THREADS = 3
all_data = 10 * np.random.randn(N_SAMPLES, 4) + 1
all_target = np.random.randint(0, 2, size=N_SAMPLES)

queue = tf.RandomShuffleQueue(capacity = 200, 
                              min_after_dequeue = 51,
                              dtypes=[tf.float32, tf.int32], 
                              shapes=[[4], []], 
                              names=None, 
                              seed=20, 
                              shared_name=None, 
                              name='random_shuffle_queue')
enqueue_op = queue.enqueue_many([all_data, all_target])
data_sample, label_sample = queue.dequeue()
qr = tf.train.QueueRunner(queue, [enqueue_op] * NUM_THREADS)

train_data_batch, train_labels_batch = tf.train.shuffle_batch([data_sample, label_sample],
                                                              capacity=200, 
                                                              min_after_dequeue=51, 
                                                              enqueue_many=False, 
                                                              batch_size=50, 
                                                              seed=20)

with tf.Session() as sess:
    # create a coordinator, launch the queue runner threads.
    coord = tf.train.Coordinator()
    enqueue_threads = qr.create_threads(sess, coord=coord, start=True)
    for step in range(10): # do to 10 iterations
        if coord.should_stop():
            break
        one_data, one_label = sess.run([train_data_batch, train_labels_batch])
        print(one_label) # I don't know why it doesn't executed
    coord.request_stop()
    coord.join(enqueue_threads)

在這種情況下,您可以使用tf.train.slice_input_producer

data_sample, label_sample = slice_input_producer(
    [all_data, all_target], num_epochs=None,
    shuffle=True, seed=None, capacity=1000, shared_name=None, name=None)
train_data_batch, train_labels_batch = tf.train.shuffle_batch([data_sample, label_sample],
    capacity=200, min_after_dequeue=51, enqueue_many=False, batch_size=50, seed=20)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM