
tf.data.Dataset from tf.keras.preprocessing.image.ImageDataGenerator.flow_from_directory?

How do I create a tf.data.Dataset from tf.keras.preprocessing.image.ImageDataGenerator.flow_from_directory?

I'm considering tf.data.Dataset.from_generator, but it's unclear how to obtain the output_types keyword argument for it, given the documented return type:

A DirectoryIterator yielding tuples of (x, y) where x is a numpy array containing a batch of images with shape (batch_size, *target_size, channels) and y is a numpy array of corresponding labels.

Both batch_x and batch_y in ImageDataGenerator are of type K.floatx(), so they are tf.float32 by default.
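To check the dtype claim rather than rely on the default, one option is to pull a single batch from the iterator and read the types off it. The sketch below is a minimal illustration, not the approach from any of the answers; the temporary folder of random images and the 16x16 target size are made-up stand-ins for a real dataset.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from PIL import Image  # Pillow is needed by flow_from_directory anyway

# Hypothetical stand-in dataset: two class folders with four tiny random images each.
root = tempfile.mkdtemp()
rng = np.random.default_rng(0)
for class_name in ("cats", "dogs"):
    os.makedirs(os.path.join(root, class_name))
    for i in range(4):
        pixels = rng.integers(0, 256, size=(20, 20, 3), dtype=np.uint8)
        Image.fromarray(pixels).save(os.path.join(root, class_name, f"{i}.png"))

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0 / 255)
gen = img_gen.flow_from_directory(root, target_size=(16, 16), batch_size=4)

# Pull one batch and read dtypes and shapes directly from the numpy arrays.
batch_x, batch_y = next(gen)
output_types = (tf.as_dtype(batch_x.dtype), tf.as_dtype(batch_y.dtype))

print(batch_x.dtype, batch_x.shape)
print(batch_y.dtype, batch_y.shape)
print(output_types)
```

The resulting output_types tuple can be passed to tf.data.Dataset.from_generator as is, so the types never have to be hard-coded.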

A similar question was already discussed at How to use Keras generator with tf.data API. Let me copy-paste the answer from there:

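import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def make_generator():
    # train_dataset_folder: path to the training images, one subfolder per class
    train_datagen = ImageDataGenerator(rescale=1. / 255)
    train_generator = train_datagen.flow_from_directory(
        train_dataset_folder, target_size=(224, 224),
        class_mode='categorical', batch_size=32)
    return train_generator

train_dataset = tf.data.Dataset.from_generator(make_generator, (tf.float32, tf.float32))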

The author faced another issue with the graph scope, but I guess it is unrelated to your question.

Or as a one-liner:

tf.data.Dataset.from_generator(
    lambda: ImageDataGenerator().flow_from_directory('folder_path'),
    (tf.float32, tf.float32))

Here is my solution. To show how it works, I use the cats/dogs dataset:

import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf


_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
#'/Users/mustafamuratarat/.keras/datasets/cats_and_dogs_filtered/train'

BATCH_SIZE = 32
IMG_SIZE = (160, 160)

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

gen = img_gen.flow_from_directory(train_dir, target_size=IMG_SIZE, batch_size=BATCH_SIZE)
#<tensorflow.python.keras.preprocessing.image.DirectoryIterator at 0x7fb9fde3b250>

#gen.class_indices
#{'cats': 0, 'dogs': 1}

#gen.target_size
#(160, 160)

# gen.batch_size
# 32

# gen.num_classes
# 2

dataset = tf.data.Dataset.from_generator(
    lambda: gen,
    output_types = (tf.float32, tf.float32),
    output_shapes = ([None, 160, 160, 3], [None, 2]),
)

#list(dataset.take(1).as_numpy_iterator())

Then you can feed the dataset object to any model.
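Two details may matter when feeding such a dataset to a model. Newer TensorFlow releases accept an output_signature argument in from_generator, which states dtype and shape together and replaces the output_types/output_shapes pair. Also, the Keras iterator loops over the images indefinitely, so the resulting dataset has no end. The sketch below illustrates both points under stated assumptions: the temporary folder of random images and the 16x16 target size are hypothetical stand-ins for train_dir and IMG_SIZE.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from PIL import Image

# Hypothetical stand-in for train_dir: two classes, four tiny random images each.
root = tempfile.mkdtemp()
rng = np.random.default_rng(0)
for class_name in ("cats", "dogs"):
    os.makedirs(os.path.join(root, class_name))
    for i in range(4):
        pixels = rng.integers(0, 256, size=(20, 20, 3), dtype=np.uint8)
        Image.fromarray(pixels).save(os.path.join(root, class_name, f"{i}.png"))

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0 / 255)
gen = img_gen.flow_from_directory(root, target_size=(16, 16), batch_size=4)

# One TensorSpec per element of the (images, labels) tuple; the batch
# dimension is None because the last batch of an epoch may be smaller.
signature = (
    tf.TensorSpec(shape=(None, 16, 16, 3), dtype=tf.float32),
    tf.TensorSpec(shape=(None, gen.num_classes), dtype=tf.float32),
)
dataset = tf.data.Dataset.from_generator(lambda: gen, output_signature=signature)

# len(gen) is the number of batches needed to see every image once.
steps_per_epoch = len(gen)

# The iterator never raises StopIteration, so asking for one batch more
# than a full epoch still succeeds.
n_batches = sum(1 for _ in dataset.take(steps_per_epoch + 1))
print(steps_per_epoch, n_batches)
```

Because the dataset is unbounded, a call such as model.fit(dataset, steps_per_epoch=steps_per_epoch, epochs=...) needs the explicit step count so that each epoch ends.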


