How to do the element-wise sum of two tf.data.Datasets, both iterating indefinitely, with tf.data.Dataset.map?

I would like to write a mixup data augmentation [1] function in my tf.data-based pipeline.

I generate one tf.data.Dataset with my training examples and a second one with the examples I want to use to augment them.

I want to map the elements feat_train, label_train of dataset_train to feat_train + feat_aug, label_train + label_aug, where feat_aug and label_aug are the elements of dataset_aug, such that both datasets are iterated indefinitely, e.g. for a dataset_train with 3 elements and a dataset_aug with 2 elements:

feat_train[0], label_train[0] -> feat_train[0] + feat_aug[0], label_train[0] + label_aug[0]
feat_train[1], label_train[1] -> feat_train[1] + feat_aug[1], label_train[1] + label_aug[1]
feat_train[2], label_train[2] -> feat_train[2] + feat_aug[0], label_train[2] + label_aug[0]
feat_train[0], label_train[0] -> feat_train[0] + feat_aug[1], label_train[0] + label_aug[1]
feat_train[1], label_train[1] -> feat_train[1] + feat_aug[0], label_train[1] + label_aug[0]
...
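
As a sanity check, the index pairing listed above can be reproduced in plain Python by cycling both sequences and zipping them. This is only an illustrative sketch without TensorFlow; `paired_indices` is a hypothetical helper name, not part of any library.

```python
from itertools import cycle, islice


def paired_indices(n_train, n_aug, steps):
    """Return the first `steps` (train_index, aug_index) pairs
    when both index sequences are cycled indefinitely and zipped."""
    pairs = zip(cycle(range(n_train)), cycle(range(n_aug)))
    return list(islice(pairs, steps))


# 3 training examples, 2 augmentation examples, first 5 steps
pairs = paired_indices(3, 2, 5)
print(pairs)  # [(0, 0), (1, 1), (2, 0), (0, 1), (1, 0)]
```

Each pair is simply (step % 3, step % 2), so the combined sequence only repeats after 6 steps.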

How can I get this behavior in my mixup function? Is there any other recommended way to perform element-wise operations on two tf.data.Datasets that iterate indefinitely?

[1] Zhang, Hongyi, et al. "mixup: Beyond empirical risk minimization." arXiv preprint arXiv:1710.09412 (2017).

# files_train and files_aug are lists of TFRecord files.

# parse TFRecords to get training example features and
# one-hot encoded labels
dataset_train = tf.data.TFRecordDataset(files_train)
dataset_train = dataset_train.map(
    lambda x: serialized2data(x, feature_shape, class_list))
dataset_train = dataset_train.shuffle(10000)
dataset_train = dataset_train.repeat()  # Repeat indefinitely.

# parse TFRecords to get augmentation example features and
# one-hot encoded labels
dataset_aug = tf.data.TFRecordDataset(files_aug)
dataset_aug = dataset_aug.map(
    lambda x: serialized2data(x, feature_shape, class_list))
dataset_aug = dataset_aug.repeat()  # Repeat indefinitely.

# augment data (mixup)
# How can I write a map function here so that the features of every item
# of dataset_train are mixed with an item of dataset_aug?
# something like
# dataset_train = dataset_train.map(
#     lambda feat_train, label_train: mixup(
#         feat_train, label_train, feat_aug, label_aug)
# )
# ?
# but how can I iterate dataset_aug to get feat_aug and label_aug ?

# make batch
dataset_train = dataset_train.batch(batch_size, drop_remainder=True)

return dataset_train


def mixup(feat_train, label_train, feat_aug, label_aug):
    # Shown as an example. This will be more complicated...
    return (feat_train + feat_aug,
            label_train + label_aug)


def serialized2data(
        serialized_data,
        feature_shape,
        class_list,
        data_format='channels_first',
        training=True):
    """Generate features, labels and, if training is False, filenames and times.
    Labels are indices of original label in class_list.

    Args:
        serialized_data: data serialized using utils.tf_utils.serialize_data
        feature_shape: shape of the features. Can be obtained with
            feature_extractor.feature_shape (see utils.feature_utils)
        class_list: list of class ids (used for one-hot encoding the labels)
        data_format: 'channels_first' (NCHW) or 'channels_last' (NHWC).
            Default is set to 'channels_first' because it is faster on GPU
            (https://www.tensorflow.org/guide/performance/overview#data_formats).
    """

    features = {
        'filename': tf.io.FixedLenFeature([], tf.string),
        'times': tf.io.FixedLenFeature([2], tf.float32),
        'data': tf.io.FixedLenFeature(feature_shape, tf.float32),
        'labels': tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(serialized_data, features)

    # reshape data to the requested data format
    if data_format == 'channels_first':
        data = tf.reshape(example['data'], (1, feature_shape[0], feature_shape[1]))
    else:
        data = tf.reshape(example['data'], (feature_shape[0], feature_shape[1], 1))

    # one-hot encode labels
    labels = tf.strings.to_number(
        tf.string_split([example['labels']], '#').values,
        out_type=tf.int32
    )

    # get intersection of class_list and labels
    labels = tf.squeeze(
        tf.sparse.to_dense(
            tf.sets.intersection(
                tf.expand_dims(labels, axis=0),
                tf.expand_dims(class_list, axis=0)
            )
        ),
        axis=0
    )

    # sort class_list and get indices of labels in class_list
    class_list = tf.sort(class_list)
    labels = tf.where(
        tf.equal(
            tf.expand_dims(labels, axis=1),
            class_list)
    )[:,1]

    tf.cond(
        tf.math.logical_and(training, tf.equal(tf.size(labels), 0)),
        true_fn=lambda:myprint(tf.strings.format('File {} has no label', example['filename'])),
        false_fn=lambda:1
    )

    one_hot = tf.cond(
        tf.equal(tf.size(labels), 0),
        true_fn=lambda: tf.zeros(tf.size(class_list)),
        false_fn=lambda: tf.reduce_max(tf.one_hot(labels, tf.size(class_list)), 0)
    )

    if training:
        return (data, one_hot)
    else:
        return (data, one_hot, example['filename'], example['times'])

Here is sample code showing how you can achieve what you asked for. I have created train_dataset and aug_dataset of length 3 and 2 respectively. Both contain images and labels. The images have shape (64, 64, 3). The labels of train are [10, 20, 30] and those of aug are [1, 2].

Pay particular attention to the label output and note that the labels repeat the way you wanted.

import numpy as np
import tensorflow as tf

tf.enable_eager_execution()  # TF 1.x only; eager execution is already the default in TF 2.x

train_dataset = tf.data.Dataset.from_tensor_slices((np.random.rand(3, 64, 64, 3), 
                                                    np.array([10, 20, 30])))
aug_dataset = tf.data.Dataset.from_tensor_slices((np.random.rand(2, 64, 64, 3), 
                                                  np.arange(1, 3)))

train_dataset = train_dataset.repeat()
aug_dataset = aug_dataset.repeat()

dataset = tf.data.Dataset.zip((train_dataset, aug_dataset))

def add_datasets(dataset1, dataset2):
  image_data = dataset1[0] + dataset2[0]
  label_data = dataset1[1] + dataset2[1]
  return image_data, label_data

dataset = dataset.map(add_datasets)

for a, b in dataset.take(10):  # the dataset is infinite, so only print the first 10 elements
  print(a.shape, b)

Output:

(64, 64, 3) tf.Tensor(11, shape=(), dtype=int64)
(64, 64, 3) tf.Tensor(22, shape=(), dtype=int64)
(64, 64, 3) tf.Tensor(31, shape=(), dtype=int64)
(64, 64, 3) tf.Tensor(12, shape=(), dtype=int64)
(64, 64, 3) tf.Tensor(21, shape=(), dtype=int64)
(64, 64, 3) tf.Tensor(32, shape=(), dtype=int64)
(64, 64, 3) tf.Tensor(11, shape=(), dtype=int64)
(64, 64, 3) tf.Tensor(22, shape=(), dtype=int64)
(64, 64, 3) tf.Tensor(31, shape=(), dtype=int64)
(64, 64, 3) tf.Tensor(12, shape=(), dtype=int64)
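
The same zip-then-map idea can be carried over to a pipeline shaped like the one in the question. The following is a minimal sketch, assuming TensorFlow 2.x (eager execution by default). The small NumPy arrays are toy stand-ins for the parsed TFRecord datasets, and `LAM` is a hypothetical fixed mixing weight used only for illustration; mixup [1] normally draws this weight from a Beta distribution. Note that `mixup` here takes two (features, label) tuples, because that is how `map` passes the elements of a zipped dataset, rather than the four separate arguments used in the question.

```python
import numpy as np
import tensorflow as tf  # assumption: TensorFlow 2.x

LAM = 0.7  # hypothetical fixed mixing weight, for illustration only


def mixup(train, aug):
    """Blend one training element with one augmentation element."""
    feat_train, label_train = train
    feat_aug, label_aug = aug
    feat = LAM * feat_train + (1.0 - LAM) * feat_aug
    label = LAM * label_train + (1.0 - LAM) * label_aug
    return feat, label


# Toy stand-ins: 3 training and 2 augmentation examples,
# channels_first features and one-hot labels over 3 classes.
feats_train = np.ones((3, 1, 4, 4), dtype=np.float32)
labels_train = np.eye(3, dtype=np.float32)
feats_aug = np.zeros((2, 1, 4, 4), dtype=np.float32)
labels_aug = np.eye(3, dtype=np.float32)[:2]

dataset_train = tf.data.Dataset.from_tensor_slices(
    (feats_train, labels_train)).repeat()  # repeat indefinitely
dataset_aug = tf.data.Dataset.from_tensor_slices(
    (feats_aug, labels_aug)).repeat()      # repeat indefinitely

# Pair the two infinite datasets element by element, then mix each pair.
dataset = tf.data.Dataset.zip((dataset_train, dataset_aug)).map(mixup)
dataset = dataset.batch(4, drop_remainder=True)

feat_batch, label_batch = next(iter(dataset))
print(feat_batch.shape, label_batch.shape)
```

Because both inputs are repeated before zipping, the zipped dataset never runs out. If dataset_train is shuffled as in the question, which training example meets which augmentation example will vary from pass to pass.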
