Read TFRecord image data with new TensorFlow Dataset API

I am having trouble reading TFRecord format image data using the "new" (TensorFlow v1.4) Dataset API. I believe the problem is that I am somehow consuming the whole dataset instead of a single batch when trying to read. I have a working example of doing this using the batch/file-queue API here: https://github.com/gnperdue/TFExperiments/tree/master/conv (well, in the example I am running a classifier, but the code to read the TFRecord images is in the DataReaders.py class).

The problem functions are, I believe, these:

import tensorflow as tf

def parse_mnist_tfrec(tfrecord, features_shape):
    tfrecord_features = tf.parse_single_example(
        tfrecord,
        features={
            'features': tf.FixedLenFeature([], tf.string),
            'targets': tf.FixedLenFeature([], tf.string)
        }
    )
    features = tf.decode_raw(tfrecord_features['features'], tf.uint8)
    features = tf.reshape(features, features_shape)
    features = tf.cast(features, tf.float32)
    targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8)
    targets = tf.one_hot(indices=targets, depth=10, on_value=1, off_value=0)
    targets = tf.cast(targets, tf.float32)
    return features, targets

class MNISTDataReaderDset:
    def __init__(self, data_reader_dict):
        # doesn't matter here
        pass

    def batch_generator(self, num_epochs=1):
        def parse_fn(tfrecord):
            return parse_mnist_tfrec(
                tfrecord, self.features_shape
            )
        dataset = tf.data.TFRecordDataset(
            self.filenames_list, compression_type=self.compression_type
        )
        dataset = dataset.map(parse_fn)
        dataset = dataset.repeat(num_epochs)
        dataset = dataset.batch(self.batch_size)
        iterator = dataset.make_one_shot_iterator()
        batch_features, batch_labels = iterator.get_next()
        return batch_features, batch_labels

Then, in use:

        batch_features, batch_labels = \
            data_reader.batch_generator(num_epochs=1)

        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            # look at 3 batches only
            for _ in range(3):
                labels, feats = sess.run([
                    batch_labels, batch_features
                ])
        except tf.errors.OutOfRangeError:
            pass
        finally:
            coord.request_stop()
            coord.join(threads)

This generates an error like:

 [[Node: Reshape_1 = Reshape[T=DT_UINT8, Tshape=DT_INT32](DecodeRaw_1, Reshape_1/shape)]]
 Input to reshape is a tensor with 50000 values, but the requested shape has 1
 [[Node: Reshape_1 = Reshape[T=DT_UINT8, Tshape=DT_INT32](DecodeRaw_1, Reshape_1/shape)]]
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,28,28,1], [?,10]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]

Does anyone have any ideas?

I have a gist with the full code in the reader example and a link to the TFRecord files (our old, good friend MNIST, in TFRecord form) here:

https://gist.github.com/gnperdue/56092626d611ae23370a21fdeeb2abe8

Thanks!

Edit - I also tried a flat_map, e.g.:

def batch_generator(self, num_epochs=1):
    """
    TODO - we can use placeholders for the list of file names and
    init with a feed_dict when we call `sess.run` - give this a
    try with one list for training and one for validation
    """
    def parse_fn(tfrecord):
        return parse_mnist_tfrec(
            tfrecord, self.features_shape
        )
    dataset = tf.data.Dataset.from_tensor_slices(self.filenames_list)
    dataset = dataset.flat_map(
        lambda filename: (
            tf.data.TFRecordDataset(
                filename, compression_type=self.compression_type
            ).map(parse_fn).batch(self.batch_size)
        )
    )
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels

I also tried using just one file and not a list (in my first way of approaching this above). No matter what, it seems TF always wants to eat the entire file into the TFRecordDataset and won't operate on single records.

Okay, I figured this out - the code above is fine. The problem was my script for creating the TFRecords. Basically, I had a block like this:

def write_tfrecord(reader, start_idx, stop_idx, tfrecord_file):
    writer = tf.python_io.TFRecordWriter(tfrecord_file)
    tfeat, ttarg = get_binary_data(reader, start_idx, stop_idx)
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'features': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[tfeat])
                ),
                'targets': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[ttarg])
                )
            }
        )
    )
    writer.write(example.SerializeToString())
    writer.close()

and I needed a block like this instead:

def write_tfrecord(reader, start_idx, stop_idx, tfrecord_file):
    writer = tf.python_io.TFRecordWriter(tfrecord_file)
    for idx in range(start_idx, stop_idx):
        tfeat, ttarg = get_binary_data(reader, idx)
        example = tf.train.Example(
            features=tf.train.Features(
                feature={
                    'features': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[tfeat])
                    ),
                    'targets': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[ttarg])
                    )
                }
            )
        )
        writer.write(example.SerializeToString())
    writer.close()

Which is to say - I was basically writing my entire block of data as one giant TFRecord when I needed to be making one per example in the data.
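
For context, the per-example helper only needs to hand back the raw bytes for a single example. The sketch below is hypothetical (the real get_binary_data isn't shown here) and assumes the reader exposes per-example numpy arrays - adjust it to whatever storage you actually use:

import numpy as np

def get_binary_data(reader, idx):
    # Hypothetical per-example helper: pull one example's image and label out of
    # the reader as numpy arrays and return their raw bytes, which is what the
    # tf.train.BytesList fields above expect.
    feat = np.asarray(reader.features[idx], dtype=np.uint8)   # e.g. shape (28, 28, 1)
    targ = np.asarray(reader.targets[idx], dtype=np.uint8)    # e.g. scalar class id
    return feat.tobytes(), targ.tobytes()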

It turns out that if you do it either way with the old file and batch-queue API, everything works - functions like tf.train.batch are auto-magically 'smart' enough to either carve the big block up or concatenate lots of single-example records into a batch, depending on what you give them. When I fixed the code that made the TFRecords file, I didn't need to change anything in my old file and batch-queue code, and it still consumed the TFRecords file just fine. However, the Dataset API is sensitive to this difference. That is why in my code above it always appeared to be consuming the entire file - it's because the entire file really was one big TFRecord.
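
For anyone who hits the same symptom, a quick sanity check (just a sketch using the TF 1.x tf.python_io API, not part of the original gist) is to count the serialized records in each file; a file that was written as one giant record will report a count of 1:

import tensorflow as tf

def count_records(tfrecord_file, gzipped=False):
    # tf_record_iterator yields one serialized tf.train.Example per record, so a
    # correctly written file reports one count per example, while a file written
    # as a single giant record reports a count of 1.
    options = None
    if gzipped:
        options = tf.python_io.TFRecordOptions(
            tf.python_io.TFRecordCompressionType.GZIP
        )
    return sum(1 for _ in tf.python_io.tf_record_iterator(tfrecord_file, options))

With the fixed writer above, the count should match the number of examples written between start_idx and stop_idx.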
