Batch sequential data coming from multiple TFRecord files with tf.data
Let's consider a dataset split into multiple TFRecord files: 1.tfrecord, 2.tfrecord, ...

I would like to generate sequences of size t (say 3) consisting of consecutive elements from the same TFRecord file; I do not want a sequence to contain elements belonging to different TFRecord files.
For instance, if we have two TFRecord files containing data like:

1.tfrecord: {0, 1, 2, ..., 7}
2.tfrecord: {1000, 1001, 1002, ..., 1007}
without any shuffling, I would like to get the following batches:

[0, 1, 2]
[1, 2, 3]
...
[5, 6, 7]
[1000, 1001, 1002]
[1001, 1002, 1003]
...
[1005, 1006, 1007]
[0, 1, 2]
...
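In other words, the per-file sliding windows I want can be sketched in plain Python (toy values only, no tf.data involved):

```python
def sliding_windows(values, size):
    """All contiguous windows of `size` consecutive elements."""
    return [values[i:i + size] for i in range(len(values) - size + 1)]

file1 = list(range(8))            # records in 1.tfrecord
file2 = list(range(1000, 1008))   # records in 2.tfrecord

# Window each file on its own, then concatenate: no window can mix files.
batches = sliding_windows(file1, 3) + sliding_windows(file2, 3)
print(batches[0], batches[5], batches[6])   # [0, 1, 2] [5, 6, 7] [1000, 1001, 1002]
```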
I know how to generate sequence data using tf.data.Dataset.window or tf.data.Dataset.batch, but I do not know how to prevent a sequence from containing elements from different files.
I'm looking for a scalable solution, i.e. one that works with hundreds of TFRecord files.
Below is my failed attempt (fully reproducible example):
import tensorflow as tf

# ****************************
# Generate toy TFRecord files
def _create_example(i):
    example = tf.train.Features(feature={'data': tf.train.Feature(int64_list=tf.train.Int64List(value=[i]))})
    return tf.train.Example(features=example)

def parse_fn(serialized_example):
    return tf.parse_single_example(serialized_example, {'data': tf.FixedLenFeature([], tf.int64)})['data']

num_tf_records = 2
records_per_file = 8

options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
for i in range(num_tf_records):
    with tf.python_io.TFRecordWriter('%i.tfrecord' % i, options=options) as writer:
        for j in range(records_per_file):
            example = _create_example(j + 1000 * i)
            writer.write(example.SerializeToString())
# ****************************

# ****************************
data = tf.data.TFRecordDataset(['0.tfrecord', '1.tfrecord'], compression_type='GZIP')\
    .map(lambda x: parse_fn(x))

data = data.window(3, 1, 1, True)\
    .repeat(-1)\
    .flat_map(lambda x: x.batch(3))\
    .batch(16)

data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

with tf.Session() as sess:
    sess.run(data_it.initializer)
    print(sess.run(next_element))
which outputs:
[[ 0 1 2] # good
[ 1 2 3] # good
[ 2 3 4] # good
[ 3 4 5] # good
[ 4 5 6] # good
[ 5 6 7] # good
[ 6 7 1000] # bad: mix of elements from 0.tfrecord and 1.tfrecord
[ 7 1000 1001] # bad
[1000 1001 1002] # good
[1001 1002 1003] # good
[1002 1003 1004] # good
[1003 1004 1005] # good
[1004 1005 1006] # good
[1005 1006 1007] # good
[ 0 1 2] # good
[ 1 2 3]] # good
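The two bad rows appear because TFRecordDataset, given a list of filenames, yields one concatenated stream of records, and window then slides straight across the file boundary. A plain-Python sketch of the same effect (toy values only):

```python
def sliding_windows(values, size):
    return [values[i:i + size] for i in range(len(values) - size + 1)]

# One concatenated stream, as TFRecordDataset produces for a list of files.
stream = list(range(8)) + list(range(1000, 1008))

# Windows containing values from both files straddle the boundary.
bad = [w for w in sliding_windows(stream, 3) if min(w) < 1000 <= max(w)]
print(bad)   # [[6, 7, 1000], [7, 1000, 1001]]
```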
I think you just need to flat_map the function you use to make the window datasets:
def make_dataset_from_filename(filename):
    data = tf.data.TFRecordDataset(filename, compression_type='GZIP')\
        .map(lambda x: parse_fn(x))
    # Window within a single file only, so no sequence can cross files.
    return data.window(3, 1, 1, True)\
        .flat_map(lambda x: x.batch(3))

# repeat and batch go on the outer dataset: a repeat(-1) inside
# make_dataset_from_filename would make flat_map loop over the first file forever.
data = tf.data.Dataset.list_files('*.tfrecord', shuffle=False)\
    .flat_map(make_dataset_from_filename)\
    .repeat(-1)\
    .batch(16)
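Conceptually, this works because flat_map windows each file's records in isolation and then chains the per-file results together; a plain-Python analogue (toy values, not the real tf.data pipeline):

```python
def sliding_windows(values, size):
    return [values[i:i + size] for i in range(len(values) - size + 1)]

files = {
    '0.tfrecord': list(range(8)),
    '1.tfrecord': list(range(1000, 1008)),
}

# flat_map analogue: window each file separately, then chain the results.
batches = [w for records in files.values() for w in sliding_windows(records, 3)]

# No window mixes elements from different files.
assert not any(min(w) < 1000 <= max(w) for w in batches)
```

If reading hundreds of files sequentially becomes a bottleneck, tf.data.Dataset.interleave(make_dataset_from_filename, cycle_length=...) can replace flat_map to pull from several files concurrently while still windowing within each file.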