
Using Keras APIs, how can I import images in batches with exactly K instances of each ID in a given batch?

I'm trying to implement batch hard triplet loss, as seen in Section 3.2 of https://arxiv.org/pdf/2004.06271.pdf .

I need to load my images so that each batch contains exactly K instances of every ID that appears in it. The batch size must therefore be a multiple of K.

My image directory is too large to fit into memory, so I am using ImageDataGenerator.flow_from_directory() to load the images, but I can't see any parameter of this function that provides the behaviour I need.

How can I achieve this batch behaviour using Keras?

As of TensorFlow 2.4, I don't see a standard way of doing that with an ImageDataGenerator.

So I think you need to write your own generator based on the tensorflow.keras.utils.Sequence class, which leaves you free to define the batch contents yourself.
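
A minimal sketch of that idea follows, written in plain Python so the batching logic can be checked without TensorFlow. The names PKBatchSampler, paths_by_id, p and k are hypothetical and not part of Keras. It implements the same `__len__`, `__getitem__` and `on_epoch_end` methods that a tensorflow.keras.utils.Sequence subclass would, but returns (path, ID) pairs; in a real pipeline, `__getitem__` would presumably load those files and return image and label arrays instead.

```python
import random


class PKBatchSampler:
    """Builds batches of P IDs with exactly K file paths each (hypothetical helper)."""

    def __init__(self, paths_by_id, p, k, seed=0):
        # Keep only IDs with at least K images, so K distinct samples always exist.
        self.paths_by_id = {i: list(v) for i, v in paths_by_id.items() if len(v) >= k}
        if len(self.paths_by_id) < p:
            raise ValueError("need at least P IDs with K or more images each")
        self.p, self.k = p, k
        self.rng = random.Random(seed)
        self.on_epoch_end()

    def __len__(self):
        # Batches per epoch: every ID is used at most once per epoch.
        return len(self.ids) // self.p

    def __getitem__(self, index):
        batch_ids = self.ids[index * self.p:(index + 1) * self.p]
        batch = []
        for id_ in batch_ids:
            # Draw K distinct images for this ID.
            for path in self.rng.sample(self.paths_by_id[id_], self.k):
                batch.append((path, id_))
        return batch

    def on_epoch_end(self):
        # Reshuffle the ID order so batches differ between epochs.
        self.ids = list(self.paths_by_id)
        self.rng.shuffle(self.ids)


# Toy data: four IDs with four (made-up) file names each.
paths_by_id = {
    "id_a": ["a0.jpg", "a1.jpg", "a2.jpg", "a3.jpg"],
    "id_b": ["b0.jpg", "b1.jpg", "b2.jpg", "b3.jpg"],
    "id_c": ["c0.jpg", "c1.jpg", "c2.jpg", "c3.jpg"],
    "id_d": ["d0.jpg", "d1.jpg", "d2.jpg", "d3.jpg"],
}
sampler = PKBatchSampler(paths_by_id, p=2, k=3)
first_batch = sampler[0]  # 2 IDs x 3 paths = 6 (path, ID) pairs
```

Because only file paths are held in memory, this approach should still work when the images themselves do not fit in RAM.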


You can try merging several data streams together in a controlled manner.

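Assuming you have K tf.data.Dataset instances (it does not matter how you create them), each supplying training instances for one particular ID, you can interleave them to get an even distribution inside each mini-batch. Note that here K is the number of IDs and N is the number of instances per ID in a batch, whereas the question uses K for the instances per ID: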

import tensorflow as tf

ds1 = ...  # Training instances with ID == 1
ds2 = ...  # Training instances with ID == 2
dsK = ...  # Training instances with ID == K

# zip yields one element per ID, flat_map flattens each K-tuple, batch groups N such tuples.
train_dataset = tf.data.Dataset.zip((ds1, ds2, ..., dsK)).flat_map(concat_datasets).batch(batch_size=N * K)

where concat_datasets is the merge function:

def concat_datasets(*elements):
    # flat_map passes one element from each zipped dataset, not the datasets themselves.
    ds = tf.data.Dataset.from_tensors(elements[0])
    for element in elements[1:]:
        ds = ds.concatenate(tf.data.Dataset.from_tensors(element))
    return ds
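
To see why a batch size of N * K gives N instances of each ID, here is a plain-Python analogue of the zip, flat_map, batch pipeline. It is an illustration only, not tf.data code: round_robin, batched, streams and n_per_id are hypothetical names, and unlike Dataset.batch (which keeps a final short batch unless drop_remainder=True) this sketch drops an incomplete last batch.

```python
from collections import Counter
from itertools import islice


def round_robin(*streams):
    """Yield one item from each stream in turn (zip, then flatten each tuple)."""
    for group in zip(*streams):  # stops at the shortest stream
        yield from group


def batched(iterable, size):
    """Group items into lists of `size`, dropping a final short batch."""
    it = iter(iterable)
    while True:
        batch = list(islice(it, size))
        if len(batch) < size:
            return
        yield batch


# Toy streams: three IDs (K = 3), each yielding four (ID, sample_index) pairs.
streams = [[(id_, j) for j in range(4)] for id_ in ("a", "b", "c")]
n_per_id = 2  # N in the answer's notation

batches = list(batched(round_robin(*streams), n_per_id * len(streams)))

# Count how often each ID occurs in every batch.
per_batch_counts = [Counter(id_ for id_, _ in batch) for batch in batches]
```

Each batch is made of N complete round-robin passes over the K streams, so every ID appears exactly N times in it.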
