
How can I explore and modify the created dataset from tf.keras.preprocessing.image_dataset_from_directory()?

Here's how I used the function:

dataset = tf.keras.preprocessing.image_dataset_from_directory(
    main_directory,
    labels='inferred',
    image_size=(299, 299),
    validation_split=0.1,
    subset='training',
    seed=123
)

I'd like to explore the created dataset much like in this example, particularly the part where it is converted to a pandas DataFrame. But my minimum goal is to check the labels and the number of files attached to each, just to confirm that it created the dataset as expected (each sub-directory being the label of the images inside it).

To be clear, the main_directory is set up like this:

main_directory
- class_a
  - 000.jpg
  - ...
- class_b
  - 100.jpg
  - ...

And I'd like to see the dataset display its info with something like this:

label     number of images
class_a   100
class_b   100

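For context, something along these lines is roughly the kind of check I have in mind (just a sketch, assuming the dataset created above keeps its class_names attribute and yields integer labels):

from collections import Counter

# Tally images per integer label across all batches, then map back to class names.
counts = Counter()
for images, labels in dataset:
    counts.update(labels.numpy().tolist())

for index, name in enumerate(dataset.class_names):
    print(name, counts[index])
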
Additionally, is it possible to remove labels and their corresponding images from the dataset? The idea is to drop a label if its number of images is below a certain threshold, or based on some other metric. This can of course be done outside this function through other means, but I'd like to know if it is possible here, and if so, how.

EDIT: For additional context, the end goal of all of this is to train a pre-trained model like this with local images divided into folders named after their classes. If there is a better way that doesn't use that function and still meets this end goal, it's welcome all the same. Thanks!

I think it would be much easier to use glob2 to get all your filenames, process them however you want, and then write a simple loading function to replace image_dataset_from_directory.

Get all your files:

import os
import glob2
import tensorflow as tf

# Relative paths like 'class_a\\000.jpg' (run this from inside main_directory).
files = glob2.glob('class_*\\*.jpg')

Then manipulate this list of filenames as desired.
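
For example, to drop every class with fewer than a certain number of images, something like the following could work (the threshold of 50 is just a placeholder):

from collections import Counter

# Count images per class folder (the first path component of each relative path).
class_counts = Counter(f.split(os.sep)[0] for f in files)

# Keep only files whose class meets the minimum count (placeholder threshold).
min_images = 50
files = [f for f in files if class_counts[f.split(os.sep)[0]] >= min_images]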

Then, make a function to load the images:

def load(file_path):
    # Read and decode the image, scale it to [0, 1], and resize to the model's input size.
    img = tf.io.read_file(file_path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.image.resize(img, size=(299, 299))
    # The first path component is the class folder, e.g. 'class_a'.
    label = tf.strings.split(file_path, os.sep)[0]
    # Binary label: 1 for 'class_a', 0 for any other class.
    label = tf.cast(tf.equal(label, 'class_a'), tf.int32)
    return img, label
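
The load function above produces a binary label (1 for class_a, 0 for everything else). If there are more than two classes, a variant along these lines (just a sketch, deriving the class list from the folder names) could return an integer class index instead:

# Sorted class folder names, e.g. ['class_a', 'class_b', ...].
class_names = sorted(set(f.split(os.sep)[0] for f in files))
class_table = tf.constant(class_names)

def load_multiclass(file_path):
    img = tf.io.read_file(file_path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.image.resize(img, size=(299, 299))
    folder = tf.strings.split(file_path, os.sep)[0]
    # Index of the matching class name: 0 for 'class_a', 1 for 'class_b', and so on.
    label = tf.argmax(tf.cast(tf.equal(class_table, folder), tf.int32))
    return img, label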

Then create your dataset for training:

train_ds = tf.data.Dataset.from_tensor_slices(files).map(load).batch(4)
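
Depending on your setup, you may also want to shuffle and prefetch, for example:

train_ds = tf.data.Dataset.from_tensor_slices(files).shuffle(len(files)).map(load).batch(4).prefetch(tf.data.AUTOTUNE)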

Then train:

model.fit(train_ds)
