
Tensorflow input pipeline for distributed training

I am trying to figure out how to set up my input pipeline for TensorFlow in distributed training. It's not clear whether the readers read from a single process and send the data to all workers, or whether each server starts its own input pipeline. How do we ensure that every worker receives different input?

I will give an example of how I do it:

import tensorflow as tf
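batch_size = 50
task_index = 2    # index of this worker
num_workers = 10  # total number of workers
num_epochs = 1    # assumed value; used below but not defined in the original
input_pattern = "gs://backet/dir/part-00*"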

Get the names of all files in the bucket that match input_pattern:

files_names = tf.train.match_filenames_once(
                input_pattern, name = "myFiles")

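Select the file names for worker task_index. tf.strided_slice works like extended slicing on a Python list, a[task_index::num_workers]: starting at index task_index, it takes every num_workers-th file. The end index 999999999 is simply a number larger than the file count: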

to_process = tf.strided_slice(files_names, [task_index],
                 [999999999], strides=[num_workers])
filename_queue = tf.train.string_input_producer(to_process,
                     shuffle=True,  # shuffle the file order
                     num_epochs=num_epochs)
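
To make the selection concrete, here is a minimal plain-Python sketch of the same slicing, with no TensorFlow involved. The file names are made up for illustration; only task_index and num_workers are taken from the example above.

```python
# Hypothetical file list standing in for the result of match_filenames_once.
file_names = ["part-%05d" % i for i in range(25)]

task_index = 2
num_workers = 10

# Same selection as the tf.strided_slice call above:
# start at task_index and take every num_workers-th file.
to_process = file_names[task_index::num_workers]

print(to_process)  # ['part-00002', 'part-00012', 'part-00022']
```

With 25 files and 10 workers, worker 2 gets files 2, 12 and 22; the other workers get the files at their own offsets.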

reader = tf.TextLineReader()
_ , value = reader.read(filename_queue)
col1,col2 = tf.decode_csv(value,
        record_defaults=[[1],[1]], field_delim="\t")

train_inputs, train_labels = tf.train.shuffle_batch([col1,[col2]],
        batch_size=batch_size,
        capacity=50*batch_size,
        num_threads=10,
        min_after_dequeue = 10*batch_size,
        allow_smaller_final_batch = True)

loss = f(...,train_inputs, train_labels)
optimizer = ...

# MonitoredTrainingSession initializes variables and starts the queue runners
# itself, so no separate Coordinator or start_queue_runners call is needed.
with tf.train.MonitoredTrainingSession(...) as mon_sess:
    while not mon_sess.should_stop():
        # Run one training step through the monitored session.
        mon_sess.run(optimizer)

I'm not sure this is the best way to implement an input pipeline for distributed TensorFlow, because each worker reads the names of all files in the bucket.
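
One way to reason about that concern: listing the names does not read the file contents, and each worker only opens the files in its own slice. The sketch below, in plain Python, checks that slicing by task index gives every file to exactly one worker. The helper shard_for_worker and the file names are hypothetical. Sorting is an added assumption, so that every worker slices the same ordering even if the listing order differs between workers.

```python
def shard_for_worker(file_names, task_index, num_workers):
    """Return the files that one worker should read (hypothetical helper)."""
    if not 0 <= task_index < num_workers:
        raise ValueError("task_index must be in [0, num_workers)")
    # Sort so that every worker slices the same ordering.
    return sorted(file_names)[task_index::num_workers]


# Hypothetical listing; in the pipeline above it would come from input_pattern.
all_files = ["part-%05d" % i for i in range(23)]
num_workers = 10

shards = [shard_for_worker(all_files, i, num_workers) for i in range(num_workers)]

# Every file is assigned to exactly one worker.
assigned = [name for shard in shards for name in shard]
assert sorted(assigned) == sorted(all_files)
assert len(set(assigned)) == len(assigned)
```

Note that if there are fewer files than workers, some shards are empty, and those workers would have nothing to read.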


A good lecture on input pipelines in TensorFlow: http://web.stanford.edu/class/cs20si/lectures/notes_09.pdf
