[英]How to preserve dict keys in tf.data.Dataset.from_generator?
[英]How to speed up tf.data.Dataset.from_generator()
在 tensorflow2.0 中,我想训练一个具有 nce 损失的 skip-gram 模型。 tf.data.Dataset.from_tensor_slices() 不适合,因为输入文件真的很大。 所以我写了一个这样的数据集生成器类:
class DataSet:
""""""
def __init__(self, args, vocab):
self.args = args
self.vocab = vocab
def generator(self):
"""a generator function, it will return skip-gram sample or cbow sample"""
with open(self.args.input) as f_input:
for line in tqdm.tqdm(f_input.readlines()):
tokens = line.strip().split()
tokens_indices = self.vocab.indices(tokens)
for index, target_word in enumerate(tokens_indices):
context_words = list()
begin = index - self.args.window_size if index - self.args.window_size > 0 else 0
end = index + 1 + self.args.window_size if index + self.args.window_size + 1 < len(tokens_indices) else len(
tokens_indices)
context_words.extend(tokens_indices[begin:index])
context_words.extend(tokens_indices[index + 1:end])
if self.args.cbow > 0:
yield context_words, target_word
else:
for i in range(len(context_words)):
yield target_word, context_words[i]
def dataset(self):
"""Using tf.data.Dataset.from_generator() to return sample"""
if self.args.cbow:
dataset = tf.data.Dataset.from_generator(
self.generator,
(tf.int32, tf.int32),
(tf.TensorShape([None]), tf.TensorShape([]))
)
else:
dataset = tf.data.Dataset.from_generator(
self.generator,
(tf.int32, tf.int32),
(tf.TensorShape([]), tf.TensorShape([]))
)
return dataset
然后我用以下方法测试我的代码:
dataset = DataSet(args, vocab).dataset()
iterator = dataset.make_one_shot_iterator()
for batch, (x,y) in enumerate(dataset.batch(128)):
pass
print(batch, x.shape, y.shape)
但是迭代所有行需要花费大量时间(在 MacBook pro 2012 中大约 10 分钟/15000 行)。 有没有什么方法可以加速代码?
如果您正在处理大型数据集,那么 TFRecord 是合适的选择。 它使用二进制文件格式来存储您的数据,并且会对导入管道的性能产生重大影响,从而对模型的训练时间产生重大影响。 二进制数据占用更少的磁盘空间,复制所需的时间更少,并且可以更有效地从磁盘读取。 如果您的数据存储在旋转磁盘上,则尤其如此,因为与 SSD 相比读/写性能要低得多。
然而,纯粹的性能并不是 TFRecord 文件格式的唯一优势。 它以多种方式针对 Tensorflow 进行了优化。 首先,它可以轻松组合多个数据集,并与库提供的数据导入和预处理功能无缝集成。 特别是对于太大而无法完全存储在内存中的数据集,这是一个优势,因为只有当时需要的数据(例如批处理)从磁盘加载然后进行处理。 TFRecords 的另一个主要优点是可以存储序列数据——例如,时间序列或单词编码——以允许非常有效和(从编码角度)方便地导入此类数据的方式。
建议通过此官方链接查看TFRecord 。 您也可以通过此链接了解如何构建 TFRecord 管道。
下面是一个简单的例子,使用TFRecordWriter
写入序列化记录,然后将其加载到TFRecordDatset
%tensorflow_version 2.x
import tensorflow as tf
print(tf.__version__)
def write_date_tfrecord():
#writes 10 dummy values to replicate the issue
Output = [20191221 + x for x in range(0,10)]
print("Writing Output - ", Output)
example = tf.train.Example(
features = tf.train.Features(
feature = {
'Output':tf.train.Feature(float_list=tf.train.FloatList(value=Output))
}
))
writer = tf.io.TFRecordWriter("Output.tf_record")
writer.write(example.SerializeToString())
def parse_function(serialized_example):
features = {
'Output': tf.io.FixedLenSequenceFeature([], tf.float32,allow_missing=True)
}
features = tf.io.parse_single_example(serialized=serialized_example, features=features)
Output = features['Output']
return Output
def dataset_generator():
trRecordDataset = tf.data.TFRecordDataset("Output.tf_record")
trRecordDataset = trRecordDataset.map(parse_function, num_parallel_calls = tf.data.experimental.AUTOTUNE)
return trRecordDataset
if __name__ == '__main__':
write_date_tfrecord()
generator = dataset_generator()
for Output in generator:
print(Output)
输出 -
2.2.0
Writing Output - [20191221, 20191222, 20191223, 20191224, 20191225, 20191226, 20191227, 20191228, 20191229, 20191230]
tf.Tensor(
[20191220. 20191222. 20191224. 20191224. 20191224. 20191226. 20191228.
20191228. 20191228. 20191230.], shape=(10,), dtype=float32)
希望这能回答你的问题。 快乐学习。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.