
Using Keras APIs, how can I import images in batches with exactly K instances of each ID in a given batch?

I'm trying to implement batch hard triplet loss, as seen in Section 3.2 of https://arxiv.org/pdf/2004.06271.pdf.

I need to load my images so that every ID that appears in a batch is represented by exactly K instances. The batch size must therefore be a multiple of K.

I have a directory of images too large to fit into memory, so I am using ImageDataGenerator.flow_from_directory() to load the images, but I can't see any parameter of this function that provides the behaviour I need.

How can I achieve this batch behaviour using Keras?

As of TensorFlow 2.4, I don't see a standard way of doing that with an ImageDataGenerator.

So I think you need to write your own generator based on the tensorflow.keras.utils.Sequence class, which leaves you free to define the batch contents yourself.

References:
https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence
https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
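
To make the idea concrete, here is a minimal sketch of the batch-selection logic such a generator could use. It is written in plain Python so it runs without TensorFlow; the class name PKBatchSampler and the parameters ids_per_batch and k are my own illustrative names, not part of Keras. It only mirrors the __len__ / __getitem__ / on_epoch_end interface of tensorflow.keras.utils.Sequence and returns file paths rather than image arrays. To use it with model.fit() you would presumably inherit from Sequence and load the images inside __getitem__.

```python
import random
from collections import defaultdict


class PKBatchSampler:
    """Builds batches holding exactly k samples for each of ids_per_batch IDs.

    Hypothetical helper: it mirrors the Sequence interface but does not
    inherit from it, and it returns (path, label) pairs instead of images.
    """

    def __init__(self, paths, labels, ids_per_batch, k, seed=None):
        self.ids_per_batch = ids_per_batch
        self.k = k
        self.rng = random.Random(seed)
        # Group the file paths by their ID.
        self.by_id = defaultdict(list)
        for path, label in zip(paths, labels):
            self.by_id[label].append(path)
        # Assumption: IDs with fewer than k images are skipped entirely.
        self.ids = [i for i, p in self.by_id.items() if len(p) >= k]
        self.on_epoch_end()

    def __len__(self):
        # Number of full batches per epoch; leftover IDs are dropped.
        return len(self.ids) // self.ids_per_batch

    def __getitem__(self, index):
        start = index * self.ids_per_batch
        batch_ids = self.ids[start:start + self.ids_per_batch]
        batch = []
        for label in batch_ids:
            # Draw k distinct images of this ID without replacement.
            for path in self.rng.sample(self.by_id[label], self.k):
                batch.append((path, label))
        return batch

    def on_epoch_end(self):
        # Reshuffle so each epoch groups the IDs differently.
        self.rng.shuffle(self.ids)
```

Each batch then has ids_per_batch * k entries, so the batch size is always a multiple of K as the question requires.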

You can try merging several data streams together in a controlled manner.

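Suppose you have K tf.data.Dataset instances (it does not matter how you instantiate them), each responsible for supplying the training instances of one particular ID. You can interleave them so that every mini-batch contains the same number of instances of each ID. Note that in this answer K is the number of IDs and N is the number of instances of each ID per batch, so the batch size is N * K: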

import tensorflow as tf

ds1 = ...  # Training instances with ID == 1
ds2 = ...  # Training instances with ID == 2
...
dsK = ...  # Training instances with ID == K

train_dataset = (
    tf.data.Dataset.zip((ds1, ds2, ..., dsK))
    .flat_map(concat_datasets)
    .batch(batch_size=N * K)
)

where concat_datasets is the merge function:

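def concat_datasets(*elements):
    # flat_map passes in one element from each zipped dataset; emit them
    # one after another as a small dataset so they can be re-batched.
    ds = tf.data.Dataset.from_tensors(elements[0])
    for element in elements[1:]:
        ds = ds.concatenate(tf.data.Dataset.from_tensors(element))
    return ds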
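
For anyone who wants to try this end to end, here is a small toy version of the pipeline. The helper make_id_dataset and the values N = 2 and three IDs are made up for illustration; in practice each per-ID dataset would read that ID's images from disk. I call repeat() on each stream because Dataset.zip stops as soon as the shortest input is exhausted.

```python
import tensorflow as tf

N = 2  # instances of each ID per batch (illustrative value)


def make_id_dataset(label, num_items):
    # Hypothetical stand-in for a real per-ID image dataset:
    # yields (feature, label) pairs for a single ID, repeated forever.
    features = tf.fill([num_items, 1], float(label))
    labels = tf.fill([num_items], label)
    return tf.data.Dataset.from_tensor_slices((features, labels)).repeat()


def concat_datasets(*elements):
    # Turn one element per ID into a tiny dataset of consecutive elements.
    ds = tf.data.Dataset.from_tensors(elements[0])
    for element in elements[1:]:
        ds = ds.concatenate(tf.data.Dataset.from_tensors(element))
    return ds


# One dataset per ID; here the IDs are simply 1, 2 and 3.
per_id = tuple(make_id_dataset(label, 5) for label in (1, 2, 3))

train_dataset = (
    tf.data.Dataset.zip(per_id)
    .flat_map(concat_datasets)
    .batch(N * len(per_id))
)

# Print the labels of the first two batches to check the composition.
for _, batch_labels in train_dataset.take(2):
    print(batch_labels.numpy())
```

Every batch should contain each of the three IDs exactly N times. If you shuffle, do it on the individual per-ID datasets before zipping, since shuffling after flat_map would break the per-batch balance.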
