
TF2.0 image generator not working using Keras ImageDataGenerator

I am trying to use tf.data.Dataset in TF2.0 together with the Keras ImageDataGenerator, but I get an error when I pass the dataset to my model. Here is what I am doing. I have a Data folder containing 4 subfolders, one per category. I assume the folder names become the labels, as with the old Keras method. Each of the 4 folders holds 72 or so images.

Here is the code that I am using to build the dataset:

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# DATA_PATH points to the Data folder described above (defined elsewhere).
augment = True
if augment:
    # Rescale pixel values to [0, 1] and apply random augmentation.
    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        shear_range=0,
        rotation_range=20,
        zoom_range=0.15,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')
else:
    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        horizontal_flip=True,
        fill_mode='nearest')

# Pull one batch to inspect the dtypes and shapes the generator yields.
images, labels = next(train_datagen.flow_from_directory(DATA_PATH))
print(images.dtype, images.shape)
print(labels.dtype, labels.shape)
input_shape = images.shape[1:]
print("InputShape:", input_shape)
img_shape = (input_shape[0], input_shape[1])

# Wrap the Keras generator in a tf.data.Dataset (no output_shapes given).
ds = tf.data.Dataset.from_generator(train_datagen.flow_from_directory,
             args=[DATA_PATH], output_types=(tf.float32, tf.float32))
print("DS:", ds)

This produces:

Found 324 images belonging to 4 classes.
float32 (32, 256, 256, 3)
float32 (32, 4)
InputShape: (256, 256, 3)
DS: <DatasetV1Adapter shapes: (<unknown>, <unknown>), types: (tf.float32, tf.float32)>

That looks right to me. But when I try to use it in my model like this:

history = model.fit(ds, epochs=10, verbose=1)

It gives me this error:

Epoch 1/10
Traceback (most recent call last):
  File "C:/Users/gus/Documents/ImageSimularity/FoodTrainer.py", line 75, in <module>
    history = model.fit(ds, epochs=10, verbose=1)
  File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\keras\engine\training.py", line 728, in fit
    use_multiprocessing=use_multiprocessing)
  File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py", line 324, in fit
    total_epochs=epochs)
  File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py", line 123, in run_one_epoch
    batch_outs = execution_function(iterator)
  File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\keras\engine\training_v2_utils.py", line 86, in execution_function
    distributed_function(input_fn))
  File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 457, in __call__
    result = self._call(*args, **kwds)
  File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 503, in _call
    self._initialize(args, kwds, add_initializers_to=initializer_map)
  File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 408, in _initialize
    *args, **kwds))
  File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\eager\function.py", line 1848, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\eager\function.py", line 2150, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\eager\function.py", line 2041, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\framework\func_graph.py", line 915, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 358, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\keras\engine\training_v2_utils.py", line 66, in distributed_function
    model, input_iterator, mode)
  File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\keras\engine\training_v2_utils.py", line 112, in _prepare_feed_values
    inputs, targets, sample_weights = _get_input_from_iterator(inputs)
  File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\keras\engine\training_v2_utils.py", line 149, in _get_input_from_iterator
    distribution_strategy_context.get_strategy(), x, y, sample_weights)
  File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\keras\distribute\distributed_training_utils.py", line 308, in validate_distributed_dataset_inputs
    x_values_list = validate_per_replica_inputs(distribution_strategy, x)
  File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\keras\distribute\distributed_training_utils.py", line 356, in validate_per_replica_inputs
    validate_all_tensor_shapes(x, x_values)
  File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\keras\distribute\distributed_training_utils.py", line 373, in validate_all_tensor_shapes
    x_shape = x_values[0].shape.as_list()
  File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\framework\tensor_shape.py", line 1171, in as_list
    raise ValueError("as_list() is not defined on an unknown TensorShape.")
ValueError: as_list() is not defined on an unknown TensorShape.
1/Unknown - 0s 10ms/step
Process finished with exit code 1

It seems like it starts to run but then stops because nothing is being produced.

Using tf.data.Dataset with the Keras ImageDataGenerator is a bit tricky. You could instead use Keras's built-in fit_generator method.

To do so, you can skip this part:

# ds = tf.data.Dataset.from_generator(train_datagen.flow_from_directory,
#             args=[DATA_PATH], output_types=(tf.float32, tf.float32))

and use the Keras generator directly:

train_generator = train_datagen.flow_from_directory(
        DATA_PATH,
        target_size=(256, 256),  # matches the InputShape printed in the question; adjust to your model
        batch_size=32,
        class_mode='categorical')  # one-hot labels, since there are 4 classes

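Finally, training can be invoked via the fit_generator method mentioned above:

model.fit_generator(
        train_generator,
        steps_per_epoch=len(train_generator),  # batches in one pass over the data
        epochs=50,
        # validation_generator is not defined above: build it the same way
        # from held-out images, or drop the two arguments below.
        validation_data=validation_generator,
        validation_steps=len(validation_generator))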
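
As a rough check on what steps_per_epoch should be for the dataset in the question (324 images, batch size 32), one pass over the data is the sample count divided by the batch size, rounded up. The helper name below is hypothetical and only illustrates the arithmetic:

```python
import math


def steps_for_one_pass(num_samples: int, batch_size: int) -> int:
    """Batches needed to see every sample once (the last batch may be partial)."""
    return math.ceil(num_samples / batch_size)


# Numbers taken from the question: 324 images, default batch size of 32.
steps = steps_for_one_pass(324, 32)
print(steps)  # 11
```

Because the Keras generator loops forever, a step count like this is what tells training where one epoch ends.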

The documentation on this topic is pretty good and I suggest checking it out. Cheers!
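
If you would rather stay with tf.data, the traceback and the `<unknown>` shapes printed for the dataset suggest the problem is that from_generator was never told the element shapes. A minimal sketch of declaring them follows, assuming batches of 256x256 RGB images and 4 one-hot classes as printed in the question. The `fake_batches` generator is a hypothetical stand-in for `train_datagen.flow_from_directory(DATA_PATH)` so the snippet runs without image files:

```python
import numpy as np
import tensorflow as tf

IMG_SHAPE = (256, 256, 3)  # InputShape printed in the question
NUM_CLASSES = 4            # number of class folders in the question


def fake_batches():
    """Hypothetical stand-in for the Keras directory iterator: yields batches forever."""
    while True:
        images = np.random.rand(8, *IMG_SHAPE).astype("float32")
        class_ids = np.random.randint(NUM_CLASSES, size=8)
        labels = np.eye(NUM_CLASSES, dtype="float32")[class_ids]  # one-hot rows
        yield images, labels


# Declaring output_shapes gives the dataset known ranks and dimensions;
# None leaves the batch dimension free, since the last batch may be smaller.
ds = tf.data.Dataset.from_generator(
    fake_batches,
    output_types=(tf.float32, tf.float32),
    output_shapes=(
        tf.TensorShape([None, *IMG_SHAPE]),
        tf.TensorShape([None, NUM_CLASSES]),
    ),
)
print(ds.element_spec)
```

With the real data, the callable would be something like `lambda: train_datagen.flow_from_directory(DATA_PATH)`. Since that iterator never stops, model.fit would also need a steps_per_epoch value. Newer TensorFlow releases accept an output_signature argument in place of output_types and output_shapes.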
