繁体   English   中英

使用 tf.data 批量处理来自多个 TFRecord 文件的顺序数据

[英]Batch sequential data coming from multiple TFRecord files with tf.data

让我们考虑将数据集拆分为多个 TFRecord 文件:

  • 1.tfrecord ,
  • 2.tfrecord ,
  • 等等。

我想生成由来自同一个 TFRecord 文件的连续元素组成的大小为t (比如3 )的序列,我不希望序列具有属于不同 TFRecord 文件的元素。

例如,如果我们有两个包含以下数据的 TFRecord 文件:

  • 1.tfrecord : {0, 1, 2, ..., 7}
  • 2.tfrecord : {1000, 1001, 1002, ..., 1007}

没有任何改组,我想得到以下批次:

  • 第一批: 0, 1, 2 ,
  • 第二批: 1, 2, 3 ,
  • ...
  • 第 i 批: 5, 6, 7 ,
  • (i+1)-th 批次: 1000, 1001, 1002 ,
  • (i+2)-th 批次: 1001, 1002, 1003 ,
  • ...
  • 第 j 批: 1005, 1006, 1007 ,
  • (j+1)-th 批次: 0, 1, 2 ,
  • 等等。

我知道如何使用tf.data.Dataset.windowtf.data.Dataset.batch生成序列数据,但我不知道如何防止序列包含来自不同文件的元素。

我正在寻找一个可扩展的解决方案,即该解决方案应该适用于数百个 TFRecord 文件。

以下是我失败的尝试(完全可重现的示例):

import tensorflow as tf

# ****************************
# Generate toy TF Record files

def _create_example(i):
    example = tf.train.Features(feature={'data': tf.train.Feature(int64_list=tf.train.Int64List(value=[i]))})
    return tf.train.Example(features=example)

def parse_fn(serialized_example):
    return tf.parse_single_example(serialized_example, {'data': tf.FixedLenFeature([], tf.int64)})['data']


num_tf_records = 2
records_per_file = 8
options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
for i in range(num_tf_records):
    with tf.python_io.TFRecordWriter('%i.tfrecord' % i, options=options) as writer:
        for j in range(records_per_file):
            example = _create_example(j + 1000 * i)
            writer.write(example.SerializeToString())
# ****************************
# ****************************


data = tf.data.TFRecordDataset(['0.tfrecord', '1.tfrecord'], compression_type='GZIP')\
            .map(lambda x: parse_fn(x))

data = data.window(3, 1, 1, True)\
           .repeat(-1)\
           .flat_map(lambda x: x.batch(3))\
           .batch(16)

data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

with tf.Session() as sess:
    sess.run(data_it.initializer)
    print(sess.run(next_element))

输出:

[[   0    1    2]   # good
 [   1    2    3]   # good
 [   2    3    4]   # good
 [   3    4    5]   # good
 [   4    5    6]   # good
 [   5    6    7]   # good
 [   6    7 1000]   # bad – mix of elements from 0.tfrecord and 1.tfrecord
 [   7 1000 1001]   # bad
 [1000 1001 1002]   # good
 [1001 1002 1003]   # good
 [1002 1003 1004]   # good
 [1003 1004 1005]   # good
 [1004 1005 1006]   # good
 [1005 1006 1007]   # good
 [   0    1    2]   # good
 [   1    2    3]]  # good

我认为你只需要flat_map你必须制作windo数据集的函数:

def make_dataset_from_filename(filename):
  data = tf.data.TFRecordDataset(filename, compression_type='GZIP')\
           .map(lambda x: parse_fn(x))

  data = data.window(3, 1, 1, True)\
             .repeat(-1)\
             .flat_map(lambda x: x.batch(3))\
             .batch(16)

tf.data.Dataset.list_files('*.tfrecord').flat_map(make_dataset_from_filename)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM