[英]Batch sequential data coming from multiple TFRecord files with tf.data
讓我們考慮將數據集拆分為多個 TFRecord 文件:
1.tfrecord
,2.tfrecord
, 我想生成由來自同一個 TFRecord 文件的連續元素組成的大小為t
(比如3
)的序列,我不希望序列具有屬於不同 TFRecord 文件的元素。
例如,如果我們有兩個包含以下數據的 TFRecord 文件:
1.tfrecord
: {0, 1, 2, ..., 7}
2.tfrecord
: {1000, 1001, 1002, ..., 1007}
沒有任何改組,我想得到以下批次:
0, 1, 2
,1, 2, 3
,5, 6, 7
,1000, 1001, 1002
,1001, 1002, 1003
,1005, 1006, 1007
,0, 1, 2
, 我知道如何使用tf.data.Dataset.window
或tf.data.Dataset.batch
生成序列數據,但我不知道如何防止序列包含來自不同文件的元素。
我正在尋找一個可擴展的解決方案,即該解決方案應該適用於數百個 TFRecord 文件。
以下是我失敗的嘗試(完全可重現的示例):
import tensorflow as tf
# ****************************
# Generate toy TF Record files
def _create_example(i):
example = tf.train.Features(feature={'data': tf.train.Feature(int64_list=tf.train.Int64List(value=[i]))})
return tf.train.Example(features=example)
def parse_fn(serialized_example):
return tf.parse_single_example(serialized_example, {'data': tf.FixedLenFeature([], tf.int64)})['data']
num_tf_records = 2
records_per_file = 8
options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
for i in range(num_tf_records):
with tf.python_io.TFRecordWriter('%i.tfrecord' % i, options=options) as writer:
for j in range(records_per_file):
example = _create_example(j + 1000 * i)
writer.write(example.SerializeToString())
# ****************************
# ****************************
data = tf.data.TFRecordDataset(['0.tfrecord', '1.tfrecord'], compression_type='GZIP')\
.map(lambda x: parse_fn(x))
data = data.window(3, 1, 1, True)\
.repeat(-1)\
.flat_map(lambda x: x.batch(3))\
.batch(16)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()
with tf.Session() as sess:
sess.run(data_it.initializer)
print(sess.run(next_element))
輸出:
[[ 0 1 2] # good
[ 1 2 3] # good
[ 2 3 4] # good
[ 3 4 5] # good
[ 4 5 6] # good
[ 5 6 7] # good
[ 6 7 1000] # bad – mix of elements from 0.tfrecord and 1.tfrecord
[ 7 1000 1001] # bad
[1000 1001 1002] # good
[1001 1002 1003] # good
[1002 1003 1004] # good
[1003 1004 1005] # good
[1004 1005 1006] # good
[1005 1006 1007] # good
[ 0 1 2] # good
[ 1 2 3]] # good
我認為你只需要flat_map
你必須制作windo
數據集的函數:
def make_dataset_from_filename(filename):
data = tf.data.TFRecordDataset(filename, compression_type='GZIP')\
.map(lambda x: parse_fn(x))
data = data.window(3, 1, 1, True)\
.repeat(-1)\
.flat_map(lambda x: x.batch(3))\
.batch(16)
tf.data.Dataset.list_files('*.tfrecord').flat_map(make_dataset_from_filename)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.