繁体   English   中英

使用TensorFlow Dataset API和flat_map的并行线程

[英]Parallel threads with TensorFlow Dataset API and flat_map

我正在将TensorFlow代码从旧队列接口更改为新的Dataset API 使用旧接口,我可以为tf.train.shuffle_batch队列指定num_threads参数。 但是,控制数据集API中线程数量的唯一方法似乎是使用num_parallel_calls参数在map函数中。 但是,我正在使用flat_map函数,它没有这样的参数。

问题 :有没有办法控制flat_map函数的线程/进程数? 或者是否有方法将mapflat_map结合使用并仍然指定并行调用的数量?

请注意,并行运行多个线程至关重要,因为我打算在数据进入队列之前在CPU上运行大量预处理。

GitHub上有两个( 这里这里 )相关的帖子,但我不认为他们回答了这个问题。

这是我用例的最小代码示例:

with tf.Graph().as_default():
    data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
    input_tensors = (data,)

    def pre_processing_func(data_):
        # normally I would do data-augmentation here
        results = (tf.expand_dims(data_, axis=0),)
        return tf.data.Dataset.from_tensor_slices(results)

    dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
    dataset = dataset_source.flat_map(pre_processing_func)
    # do something with 'dataset'

据我所知,目前flat_map不提供并行选项。 鉴于大部分计算是在pre_processing_func完成的,您可以使用的解决方法是并行map调用,然后进行一些缓冲,然后使用带有标识lambda函数的flat_map调用来处理输出的扁平化。

在代码中:

NUM_THREADS = 5
BUFFER_SIZE = 1000

def pre_processing_func(data_):
    # data-augmentation here
    # generate new samples starting from the sample `data_`
    artificial_samples = generate_from_sample(data_)
    return atificial_samples

dataset_source = (tf.data.Dataset.from_tensor_slices(input_tensors).
                  map(pre_processing_func, num_parallel_calls=NUM_THREADS).
                  prefetch(BUFFER_SIZE).
                  flat_map(lambda *x : tf.data.Dataset.from_tensor_slices(x)).
                  shuffle(BUFFER_SIZE)) # my addition, probably necessary though

注意(对我自己和试图理解管道的人):

由于pre_processing_func从初始样本开始生成任意数量的新样本(以形状矩阵(?, 512) ),因此需要使用flat_map调用将所有生成的矩阵转换为包含单个样本的Dataset (因此tf.data.Dataset.from_tensor_slices(x) ,然后将所有这些数据集展平为一个包含单个样本的大Dataset

.shuffle() ,将数据集或生成的样本打包在一起可能是一个好主意。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM