
TensorFlow: concatenate tf.data.Dataset batches

When using tf.data.Dataset, is it possible to concatenate the batches of two datasets so that the second dataset is not simply appended after the end of the first? Instead, the first batch of the second dataset should be concatenated onto the first batch of the first dataset, the second batch onto the second batch, and so on.

I tried the following, but it gave me batches of length 40, whereas I expected length 80 (40 elements from each dataset):

train_data = train_data.batch(40).concatenate(augmentation_data.batch(40))
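The reason the attempt above gives 40-element batches is that concatenate() appends the second dataset's batches after all of the first dataset's batches; it never merges two batches into one. A minimal sketch, assuming TensorFlow 2.x eager execution and small toy datasets made up for illustration:

```python
import tensorflow as tf

a = tf.data.Dataset.range(0, 4)      # toy data: 0, 1, 2, 3
b = tf.data.Dataset.range(100, 104)  # toy data: 100, 101, 102, 103

# concatenate() yields every batch of `a`, then every batch of `b`;
# each batch keeps its own size (2 here), so no batch grows to 4.
combined = a.batch(2).concatenate(b.batch(2))

batches = [batch.numpy().tolist() for batch in combined]
print(batches)  # [[0, 1], [2, 3], [100, 101], [102, 103]]
```

The number of batches adds up, but the size of each batch does not change.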

I'm not exactly sure what your use case is, but you might want to concatenate the feature and label tensors of the two batches separately, like this:

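import tensorflow as tf

def concat_batches(x, y):
    # x and y are (features_dict, labels) batches, one from each dataset
    features1, labels1 = x
    features2, labels2 = y
    # join every feature tensor and the labels along the batch dimension (axis 0)
    return ({feature: tf.concat([features1[feature], features2[feature]], axis=0) for feature in features1.keys()}, tf.concat([labels1, labels2], axis=0))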
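To apply concat_batches to every pair of batches instead of a single pair, one option in TensorFlow 2.x is to pair up the batched datasets with tf.data.Dataset.zip and merge each pair with map. This is a minimal sketch, assuming TF 2.x; the toy contents of train_data and augmentation_data are made up to stand in for the real datasets:

```python
import tensorflow as tf

def concat_batches(x, y):
    # x and y are (features_dict, labels) batches, one from each dataset
    features1, labels1 = x
    features2, labels2 = y
    return ({feature: tf.concat([features1[feature], features2[feature]], axis=0)
             for feature in features1.keys()},
            tf.concat([labels1, labels2], axis=0))

# Toy stand-ins (assumption): 4 examples each, one feature named "test".
train_data = tf.data.Dataset.from_tensor_slices(({"test": [[1]] * 4}, [1] * 4))
augmentation_data = tf.data.Dataset.from_tensor_slices(({"test": [[2]] * 4}, [2] * 4))

# zip pairs batch i of train_data with batch i of augmentation_data;
# map receives the two (features, labels) tuples as x and y and merges them.
merged = tf.data.Dataset.zip(
    (train_data.batch(2), augmentation_data.batch(2))).map(concat_batches)

results = [(f["test"].numpy().tolist(), l.numpy().tolist()) for f, l in merged]
print(results)
```

With batch(40) on both datasets, each merged batch holds up to 80 elements. Note that zip stops as soon as the shorter of the two batched datasets is exhausted.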

Here is an example (TensorFlow 1.x graph mode, where make_one_shot_iterator() is a Dataset method):

dataset = tf.data.Dataset.from_tensor_slices(({"test": [[1], [1], [1], [1]]}, [1, 1, 1, 1]))
b1 = dataset.repeat().batch(3).make_one_shot_iterator().get_next()
dataset2 = tf.data.Dataset.from_tensor_slices(({"test": [[2], [2], [2], [2]]}, [2, 2, 2, 2]))
b2 = dataset2.repeat().batch(3).make_one_shot_iterator().get_next()

b_con = concat_batches(b1, b2) #tensors of batches 1 and 2 have shape (3, 1), features of the concatenated batch (6, 1)

When you evaluate b_con (for example with sess.run(b_con) inside a tf.Session), it looks like this:

({'test': array([[1],
       [1],
       [1],
       [2],
       [2],
       [2]], dtype=int32)}, array([1, 1, 1, 2, 2, 2], dtype=int32))
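In TensorFlow 2.x, Dataset no longer has a make_one_shot_iterator() method (the TF1 behaviour lives on as tf.compat.v1.data.make_one_shot_iterator). A sketch of the same example in eager mode, assuming TF 2.x, where iter() and next() take the place of the one-shot iterator:

```python
import tensorflow as tf

def concat_batches(x, y):
    # x and y are (features_dict, labels) batches, one from each dataset
    features1, labels1 = x
    features2, labels2 = y
    return ({feature: tf.concat([features1[feature], features2[feature]], axis=0)
             for feature in features1.keys()},
            tf.concat([labels1, labels2], axis=0))

dataset = tf.data.Dataset.from_tensor_slices(({"test": [[1], [1], [1], [1]]}, [1, 1, 1, 1]))
dataset2 = tf.data.Dataset.from_tensor_slices(({"test": [[2], [2], [2], [2]]}, [2, 2, 2, 2]))

# Eager mode: take the first batch of each repeated, batched dataset directly
b1 = next(iter(dataset.repeat().batch(3)))
b2 = next(iter(dataset2.repeat().batch(3)))

features, labels = concat_batches(b1, b2)
print(features["test"].numpy().tolist())  # [[1], [1], [1], [2], [2], [2]]
print(labels.numpy().tolist())            # [1, 1, 1, 2, 2, 2]
```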

Hope this helps!
