
How to speed up tf.data.Dataset.from_generator()

In TensorFlow 2.0, I want to train a skip-gram model with NCE loss. tf.data.Dataset.from_tensor_slices() is not suitable because the input file is really huge. So I wrote a dataset generator class like this:

import tqdm
import tensorflow as tf


class DataSet:
    """Builds skip-gram or CBOW training samples from a text file."""

    def __init__(self, args, vocab):
        self.args = args
        self.vocab = vocab

    def generator(self):
        """A generator function that yields skip-gram or CBOW samples."""
        with open(self.args.input) as f_input:
            # Iterate the file lazily instead of readlines(), so the huge
            # input is never loaded into memory all at once.
            for line in tqdm.tqdm(f_input):
                tokens = line.strip().split()
                tokens_indices = self.vocab.indices(tokens)
                for index, target_word in enumerate(tokens_indices):
                    # Clamp the context window to the sentence boundaries.
                    begin = max(0, index - self.args.window_size)
                    end = min(len(tokens_indices), index + 1 + self.args.window_size)
                    context_words = tokens_indices[begin:index] + tokens_indices[index + 1:end]
                    if self.args.cbow > 0:
                        yield context_words, target_word
                    else:
                        for context_word in context_words:
                            yield target_word, context_word

    def dataset(self):
        """Using tf.data.Dataset.from_generator() to return sample"""
        if self.args.cbow:
            dataset = tf.data.Dataset.from_generator(
                self.generator,
                (tf.int32, tf.int32),
                (tf.TensorShape([None]), tf.TensorShape([]))
            )
        else:
            dataset = tf.data.Dataset.from_generator(
                self.generator,
                (tf.int32, tf.int32),
                (tf.TensorShape([]), tf.TensorShape([]))
            )

        return dataset

Then I test my code as follows:

dataset = DataSet(args, vocab).dataset()
# In TF 2.x the dataset is iterated directly; make_one_shot_iterator()
# no longer exists on tf.data.Dataset.
for batch, (x, y) in enumerate(dataset.batch(128)):
    pass
print(batch, x.shape, y.shape)

But it takes a lot of time to iterate over all the lines (about 10 minutes for 15,000 lines on a 2012 MacBook Pro). Is there any way to speed up this code?

If you are working with large datasets then TFRecord is a suitable option. It uses a binary file format to store your data, which can have a significant impact on the performance of your import pipeline and, as a consequence, on the training time of your model. Binary data takes up less space on disk, takes less time to copy and can be read much more efficiently from disk. This is especially true if your data is stored on spinning disks, due to their much lower read/write performance in comparison with SSDs.

However, pure performance isn't the only advantage of the TFRecord file format. It is optimized for use with TensorFlow in multiple ways. To start with, it makes it easy to combine multiple datasets and integrates seamlessly with the data import and preprocessing functionality provided by the library. Especially for datasets that are too large to be stored fully in memory this is an advantage, as only the data that is required at the time (e.g. a batch) is loaded from disk and then processed. Another major advantage of TFRecords is that it is possible to store sequence data, for instance a time series or word encodings, in a way that allows for very efficient and (from a coding perspective) convenient import of this type of data.
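
As a rough sketch of how this could apply to the skip-gram pairs produced by your generator (the skip-gram branch, where each sample is a pair of scalar word indices; the file name skipgram.tf_record and the feature names are my assumptions), you could serialize the pairs once up front and read the binary file back on every epoch instead of re-running the Python generator:

def pair_to_example(target_word, context_word):
    # Wrap one (target, context) pair of word indices in a tf.train.Example.
    return tf.train.Example(features=tf.train.Features(feature={
        'target': tf.train.Feature(int64_list=tf.train.Int64List(value=[target_word])),
        'context': tf.train.Feature(int64_list=tf.train.Int64List(value=[context_word])),
    }))

# One-off conversion: iterate the slow Python generator a single time
# and persist every pair; later epochs read the binary file directly.
with tf.io.TFRecordWriter("skipgram.tf_record") as writer:
    for target_word, context_word in DataSet(args, vocab).generator():
        writer.write(pair_to_example(target_word, context_word).SerializeToString())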

I would recommend going through the official TFRecord documentation for a first glimpse, and also reading up on how to build a TFRecord pipeline.

Here is a simple example that writes a serialized record using TFRecordWriter and then loads it in a TFRecordDataset:

%tensorflow_version 2.x
import tensorflow as tf
print(tf.__version__)

def write_data_tfrecord():
    # Write 10 dummy values to replicate the issue.
    output = [20191221 + x for x in range(10)]
    print("Writing Output - ", output)

    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'Output': tf.train.Feature(float_list=tf.train.FloatList(value=output))
            }
        ))

    # The context manager ensures the record file is flushed and closed.
    with tf.io.TFRecordWriter("Output.tf_record") as writer:
        writer.write(example.SerializeToString())

def parse_function(serialized_example):
    # Declare the expected feature layout, then parse one serialized record.
    features = {
        'Output': tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True)
    }
    features = tf.io.parse_single_example(serialized=serialized_example, features=features)
    return features['Output']

def dataset_generator():
    tfrecord_dataset = tf.data.TFRecordDataset("Output.tf_record")
    # Parse records in parallel so decoding keeps up with the training loop.
    tfrecord_dataset = tfrecord_dataset.map(parse_function, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return tfrecord_dataset

if __name__ == '__main__':
    write_data_tfrecord()
    dataset = dataset_generator()
    for output in dataset:
        print(output)

Output -

2.2.0
Writing Output -  [20191221, 20191222, 20191223, 20191224, 20191225, 20191226, 20191227, 20191228, 20191229, 20191230]
tf.Tensor(
[20191220. 20191222. 20191224. 20191224. 20191224. 20191226. 20191228.
 20191228. 20191228. 20191230.], shape=(10,), dtype=float32)
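
Independent of the storage format, the standard tf.data consumption pattern also helps: batch the parsed records and let the runtime prefetch the next batch while the current one is being processed. A minimal sketch using the dataset from above (batch size 128 is taken from your test code; the shuffle buffer size is an arbitrary assumption):

dataset = dataset_generator()
dataset = dataset.shuffle(buffer_size=10000)  # assumed buffer size
dataset = dataset.batch(128)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

for batch in dataset:
    pass  # training step goes here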

Hope this answers your question. Happy Learning.
