繁体   English   中英

Tensorflow:具有任意尺寸张量的批处理TFRecord数据集

[英]Tensorflow: Batch TFRecord Dataset with tensors of arbitrary dimensions

如何使用TFRecordsDataset批处理任意形状的张量?

我目前正在研究对象检测网络的输入管道,并且正在为标签批量处理而苦苦挣扎。 标签由边界框坐标和图像中的对象类别组成。 由于图像中可能有多个对象,因此标签尺寸是任意的


使用tf.train.batch ,可以设置dynamic_padding=True以使形状适合相同的尺寸。 但是, data.TFRecordDataset.batch()没有此类选项。

我想要批处理的所需形状对于我的[batch_size, arbitrary , 4][batch_size, arbitrary , 4]对于类为[batch_size, arbitrary, 1]

def decode(serialized_example):
"""
Decodes the information of the TFRecords to image, label_coord, label_classes
Later on will also contain the Image Sequence!

:param serialized_example: Serialized Example read from the TFRecords
:return: image, label_coordinates list, label_classes list
"""
features = {'image/shape': tf.FixedLenFeature([], tf.string),
            'train/image': tf.FixedLenFeature([], tf.string),
            'label/coordinates': tf.VarLenFeature(tf.float32),
            'label/classes': tf.VarLenFeature(tf.string)}

features = tf.parse_single_example(serialized_example, features=features)

image_shape = tf.decode_raw(features['image/shape'], tf.int64)
image = tf.decode_raw(features['train/image'], tf.float32)
image = tf.reshape(image, image_shape)

# Contains the Bounding Box coordinates in a flattened tensor
label_coord = features['label/coordinates']
label_coord = label_coord.values
label_coord = tf.reshape(label_coord, [1, -1, 4])

# Contains the Classes of the BBox in a flattened Tensor
label_classes = features['label/classes']
label_classes = label_classes.values
label_classes = tf.reshape(label_classes, [1, -1, 1])


return image, label_coord, label_classes

    dataset = tf.data.TFRecordDataset(filename)
    dataset = dataset.map(decode)
    dataset = dataset.map(augment)
    dataset = dataset.map(normalize)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)

    dataset = dataset.batch(batch_size)

引发的错误是Cannot batch tensors with different shapes in component 1. First element had shape [1,1,4] and element 1 had shape [1,7,4].

目前, augmentnormalize功能还只是占位符。

事实证明, tf.data.TFRecordDataset还有另一个名为padded_batch函数,它基本上可以完成tf.train.batch(dynamic_pad=True)工作。 这很容易解决了问题...

dataset = tf.data.TFRecordDataset(filename)

dataset = dataset.map(decode)
dataset = dataset.map(augment)
dataset = dataset.map(normalize)

dataset = dataset.shuffle(1000+3*batch_size)
dataset = dataset.repeat(num_epochs)
dataset = dataset.padded_batch(batch_size,
                               drop_remainder=False,
                               padded_shapes=([None, None, None],
                                              [None, 4],
                                              [None, 1])
                              )

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM