
Folder structure when using Keras' ImageDataGenerator

I have a training set of images with a structure like this:

/howler-monkey/
    1.jpg
    2.jpg
    ...jpg
/japanese-mcaque
    1.jpg
    2.jpg
    ...

There are 10 classes in total.

I am trying to augment the images and save them to disk, but I would like to preserve the folder structure, like this:

/augmented/
    /howler-monkey
        aug_1.jpg
        aug_2.jpg
    /japanese-mcaque
        aug_1.jpg
        aug_2.jpg
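Creating that mirrored tree of empty class folders up front needs only the standard library. This is a minimal sketch, assuming every immediate subdirectory of the training folder is one class; the function name `mirror_class_dirs` is hypothetical.

```python
from pathlib import Path


def mirror_class_dirs(train_dir, out_dir):
    """Create out_dir/<class> for every class subfolder of train_dir.

    Assumes each immediate subdirectory of train_dir is one class.
    Returns the sorted list of class names that were mirrored.
    """
    train_root = Path(train_dir)
    out_root = Path(out_dir)
    # Materialize the class list before creating anything, so a new output
    # folder inside train_dir is never mistaken for a class
    classes = sorted(p.name for p in train_root.iterdir() if p.is_dir())
    for name in classes:
        # parents=True also creates out_dir itself; exist_ok makes reruns safe
        (out_root / name).mkdir(parents=True, exist_ok=True)
    return classes
```

For example, `mirror_class_dirs('dataset/training', 'dataset/augmented')` would return the class names and leave one empty folder per class under `dataset/augmented`, ready to be used as `save_to_dir` targets.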

When I simply run:

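from keras.preprocessing.image import ImageDataGenerator

trainDataGenerator = ImageDataGenerator(shear_range=0.2, zoom_range=0.2,
    horizontal_flip=True, rotation_range=20, width_shift_range=0.2,
    height_shift_range=0.2, fill_mode='nearest')

# args holds the script's command-line arguments (args.dataset, args.output)
fileIterator = trainDataGenerator.flow_from_directory('{}/training'.format(args.dataset),
    save_to_dir='{}/{}'.format(args.dataset, args.output))

# The iterator never ends on its own, so count batches and stop;
# every batch drawn is also written to save_to_dir as a side effect
i = 0
for images, labels in fileIterator:
    i += 1
    if i == 10:  # stop after 10 batches
        break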

This dumps the augmented images into the augmented/ folder, but it does not preserve the directory structure, which makes the output hard to use for training.
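The per-class folders matter because, when `classes` is not given, `flow_from_directory` treats each subdirectory of the target directory as one class and assigns label indices in alphanumeric order of the folder names; a single flat folder of augmented images gives it no classes to find. The stdlib sketch below imitates that inference to show what the layout encodes. It is an illustration, not Keras's own code: the function name `infer_labels`, the extension set, and the decision to ignore nested subfolders are assumptions.

```python
from pathlib import Path

# Assumed set of image extensions; Keras's own whitelist is longer
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def infer_labels(root):
    """Return (class_names, [(image_path, label_index), ...]) for a class-per-folder tree.

    Class folders are sorted alphanumerically and numbered from 0.
    Only files directly inside each class folder are considered.
    """
    root = Path(root)
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    samples = []
    for index, name in enumerate(class_names):
        for image in sorted((root / name).iterdir()):
            if image.is_file() and image.suffix.lower() in IMAGE_EXTS:
                samples.append((image, index))
    return class_names, samples
```

Run on a flat folder such as the augmented/ output above, it finds no class folders and therefore no labelled samples.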

How can I preserve the original directory structure when augmenting images?

I ended up calling .flow() on each image individually and creating the class directories manually with pathlib:

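import os
import pathlib

import cv2
import numpy as np
from imutils.paths import list_images  # assumed source of list_images
from keras.preprocessing.image import ImageDataGenerator, img_to_array

trainDataGenerator = ImageDataGenerator(shear_range=0.2, zoom_range=0.2,
    horizontal_flip=True, rotation_range=20, width_shift_range=0.2,
    height_shift_range=0.2)

for path in list_images(args.dataset):
    # Skip files inside the output folder so augmented images are not augmented again
    if args.output in path.split(os.path.sep):
        continue

    # OpenCV loads images as BGR; convert to RGB so the saved JPEGs keep their colors
    img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
    img = img_to_array(img)
    img = np.expand_dims(img, axis=0)  # flow() expects a batch: (1, height, width, channels)

    # Recreate the image's class folder (its parent directory) under the output folder
    classDir = os.path.join(args.dataset, args.output, path.split(os.path.sep)[-2])
    pathlib.Path(classDir).mkdir(parents=True, exist_ok=True)
    print(path)

    # Draw 10 augmented batches of size 1; each is saved into classDir as it is generated
    total = 0
    for image in trainDataGenerator.flow(img, batch_size=1, save_to_dir=classDir,
            save_prefix='aug', save_format='jpeg'):
        print(total)
        total += 1
        if total == 10:
            break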

Here args.dataset is a string with the path to the folder containing the training images, and args.output is a string with the name of the output folder (augmentedImages), which is created inside args.dataset.
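For completeness, here is one way `args` could be defined, together with a pathlib-based way to derive the class folder (`Path(path).parent.name`) that does not depend on splitting on `os.path.sep`. This is a sketch under assumptions: the flag names `--dataset` and `--output`, the `build_parser` helper, and the helper `class_output_dir` are hypothetical, not taken from the original script.

```python
import argparse
from pathlib import Path


def build_parser():
    # Hypothetical flags that produce args.dataset and args.output
    parser = argparse.ArgumentParser(description="Augment images per class folder")
    parser.add_argument("--dataset", required=True,
                        help="folder containing the training images")
    parser.add_argument("--output", default="augmentedImages",
                        help="output folder name, created inside --dataset")
    return parser


def class_output_dir(image_path, dataset, output):
    """Map .../<class>/<file> to <dataset>/<output>/<class>."""
    class_name = Path(image_path).parent.name  # equivalent to path.split(os.path.sep)[-2]
    return Path(dataset) / output / class_name
```

In the answer's loop, the per-image output folder could then be computed as `class_output_dir(path, args.dataset, args.output)` and created with `.mkdir(parents=True, exist_ok=True)`.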
