
Import batches of images for non-classification problems

I'm trying to build a super-resolution network, but I am having trouble importing my own data. I have two types of images: resized (smaller) images and the original images. The first will be used as the input of the network and the second as the target for training the network.

The problem is that I need to load my images in batches because my computer doesn't have enough GPU memory to build the whole dataset at once. I thought that using the following code could work:

train_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

The problem is that I only know how to make it work for classification problems because, as far as I can tell, it is designed to produce only training and validation datasets of labeled images.

To do the super-resolution I need four datasets:

normal-size-train

small-size-train

normal-size-test

small-size-test

NOTE: My program works when I create one tensor for the resized images and another for the original images, but now I want to work with a larger dataset.
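For reference, tf.keras.utils.image_dataset_from_directory can also load images without labels (labels=None), so one possible way to pair the two training folders is to zip two unlabeled datasets. This is only a sketch: the folder names come from the list above, while the image sizes and the assumption that the small and normal folders contain correspondingly named files are mine.

import tensorflow as tf

batch_size = 16
low_res_size = (128, 128)    # placeholder sizes
high_res_size = (512, 512)

# shuffle=False keeps the (sorted) file order identical in both folders,
# so zipping pairs each small image with its normal-size counterpart.
small_train = tf.keras.utils.image_dataset_from_directory(
    "small-size-train", labels=None, label_mode=None, shuffle=False,
    image_size=low_res_size, batch_size=batch_size)
normal_train = tf.keras.utils.image_dataset_from_directory(
    "normal-size-train", labels=None, label_mode=None, shuffle=False,
    image_size=high_res_size, batch_size=batch_size)

# Pair (input, target) batches and shuffle at the batch level.
train_ds = tf.data.Dataset.zip((small_train, normal_train)).shuffle(100)

The same pattern would apply to the two test folders.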

I think it's better to implement a data generator for this task. Here is an example. You can/must add image resizing if the images in the dataset do not all have the same shape.

import glob
import tensorflow as tf

# Minimal decode helper (not shown in the original answer): decodes a
# JPEG/PNG file into a float32 tensor in [0, 1].
def decode_img(img_bytes):
    img = tf.io.decode_image(img_bytes, channels=3, expand_animations=False)
    return tf.image.convert_image_dtype(img, tf.float32)

def image_generator(path, batch_size=16):
    # `path` is a glob pattern matching the small (low-resolution) images;
    # the corresponding high-resolution file is found by replacing "small"
    # with "normal" in each file path.
    list_path = glob.glob(path)
    while True:
        list_of_low_dim_images = []
        list_of_high_dim_images = []
        for img_path in list_path:
            pair_path = img_path.replace("small", "normal")
            small_img = decode_img(tf.io.read_file(img_path))
            normal_img = decode_img(tf.io.read_file(pair_path))
            list_of_low_dim_images.append(small_img)
            list_of_high_dim_images.append(normal_img)
            if len(list_of_low_dim_images) == batch_size:
                # tf.stack assumes every image in the batch has the same
                # shape; resize in decode_img if that is not the case.
                yield (tf.stack(list_of_low_dim_images),
                       tf.stack(list_of_high_dim_images))
                list_of_low_dim_images = []
                list_of_high_dim_images = []
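If it helps, here is one possible way (not part of the original answer) to feed this generator to Keras through tf.data.Dataset.from_generator. The glob pattern is an assumed folder layout based on the question, and model stands for your compiled super-resolution network.

train_ds = tf.data.Dataset.from_generator(
    lambda: image_generator("small-size-train/*.png", batch_size=16),
    output_signature=(
        tf.TensorSpec(shape=(None, None, None, 3), dtype=tf.float32),  # low-res batch
        tf.TensorSpec(shape=(None, None, None, 3), dtype=tf.float32),  # high-res batch
    ),
)

# The generator loops forever, so fit() needs steps_per_epoch to know when an
# epoch ends. `model` is your compiled super-resolution model (hypothetical).
steps = len(glob.glob("small-size-train/*.png")) // 16
model.fit(train_ds, epochs=10, steps_per_epoch=steps)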
