簡體   English   中英

使用 tf.data 批量處理來自多個 TFRecord 文件的順序數據

[英]Batch sequential data coming from multiple TFRecord files with tf.data

讓我們考慮將數據集拆分為多個 TFRecord 文件:

  • 1.tfrecord ,
  • 2.tfrecord ,
  • 等等。

我想生成由來自同一個 TFRecord 文件的連續元素組成的大小為t (比如3 )的序列,我不希望序列具有屬於不同 TFRecord 文件的元素。

例如,如果我們有兩個包含以下數據的 TFRecord 文件:

  • 1.tfrecord : {0, 1, 2, ..., 7}
  • 2.tfrecord : {1000, 1001, 1002, ..., 1007}

沒有任何改組,我想得到以下批次:

  • 第一批: 0, 1, 2 ,
  • 第二批: 1, 2, 3 ,
  • ...
  • 第 i 批: 5, 6, 7 ,
  • (i+1)-th 批次: 1000, 1001, 1002 ,
  • (i+2)-th 批次: 1001, 1002, 1003 ,
  • ...
  • 第 j 批: 1005, 1006, 1007 ,
  • (j+1)-th 批次: 0, 1, 2 ,
  • 等等。

我知道如何使用tf.data.Dataset.windowtf.data.Dataset.batch生成序列數據,但我不知道如何防止序列包含來自不同文件的元素。

我正在尋找一個可擴展的解決方案,即該解決方案應該適用於數百個 TFRecord 文件。

以下是我失敗的嘗試(完全可重現的示例):

import tensorflow as tf

# ****************************
# Generate toy TF Record files

def _create_example(i):
    example = tf.train.Features(feature={'data': tf.train.Feature(int64_list=tf.train.Int64List(value=[i]))})
    return tf.train.Example(features=example)

def parse_fn(serialized_example):
    return tf.parse_single_example(serialized_example, {'data': tf.FixedLenFeature([], tf.int64)})['data']


num_tf_records = 2
records_per_file = 8
options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
for i in range(num_tf_records):
    with tf.python_io.TFRecordWriter('%i.tfrecord' % i, options=options) as writer:
        for j in range(records_per_file):
            example = _create_example(j + 1000 * i)
            writer.write(example.SerializeToString())
# ****************************
# ****************************


data = tf.data.TFRecordDataset(['0.tfrecord', '1.tfrecord'], compression_type='GZIP')\
            .map(lambda x: parse_fn(x))

data = data.window(3, 1, 1, True)\
           .repeat(-1)\
           .flat_map(lambda x: x.batch(3))\
           .batch(16)

data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

with tf.Session() as sess:
    sess.run(data_it.initializer)
    print(sess.run(next_element))

輸出:

[[   0    1    2]   # good
 [   1    2    3]   # good
 [   2    3    4]   # good
 [   3    4    5]   # good
 [   4    5    6]   # good
 [   5    6    7]   # good
 [   6    7 1000]   # bad – mix of elements from 0.tfrecord and 1.tfrecord
 [   7 1000 1001]   # bad
 [1000 1001 1002]   # good
 [1001 1002 1003]   # good
 [1002 1003 1004]   # good
 [1003 1004 1005]   # good
 [1004 1005 1006]   # good
 [1005 1006 1007]   # good
 [   0    1    2]   # good
 [   1    2    3]]  # good

我認為你只需要flat_map你必須制作windo數據集的函數:

def make_dataset_from_filename(filename):
  data = tf.data.TFRecordDataset(filename, compression_type='GZIP')\
           .map(lambda x: parse_fn(x))

  data = data.window(3, 1, 1, True)\
             .repeat(-1)\
             .flat_map(lambda x: x.batch(3))\
             .batch(16)

tf.data.Dataset.list_files('*.tfrecord').flat_map(make_dataset_from_filename)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM