简体   繁体   English

在 RGBA PNG 上训练 Tensorflow 2.0?

[英]Training Tensorflow 2.0 on RGBA PNGs?

I recently upgraded my Colab instance to TensorFlow 2.0 and have attempted to train a Sequential Classification model on a batch of PNG images.我最近将我的 Colab 实例升级到 TensorFlow 2.0 并尝试在一批 PNG 图像上训练顺序分类 model。 I have used the !tf_upgrade_v2 command listed on tensorflow.org to upgrade the script from TensorFlow 1.14 to 2.0 format.我已使用 tensorflow.org 上列出的!tf_upgrade_v2命令将脚本从 TensorFlow 1.14 升级到 2.0 格式。

When I attempt to train the model using model.fit_generator code below I get a re-occurring non-fatal UserWarning line after line as the model runs through each epoch.当我尝试使用下面的model.fit_generator代码训练 model 时,当 Z20F35E630DAF44DBFA4C3F68F5399DepC 运行通过每个 DepC 时,我得到一个重复出现的非致命 UserWarning 行。

# Train the model
history = model.fit_generator(
    train_generator,
    steps_per_epoch=nb_train_samples // batch_size,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=nb_validation_samples // batch_size)

An example of the warning is:警告的一个例子是:

1/449 [..............................] - ETA: 12:04 - loss: 3.0403 - accuracy: 0.1250
/usr/local/lib/python3.6/dist-packages/PIL/Image.py:914: UserWarning: Palette images 
with Transparency expressed in bytes should be converted to RGBA images to RGBA
images')

This warning did not appear when using TensorFlow 1.14 when trained on the same batch of PNG images.使用 TensorFlow 1.14 在同一批 PNG 图像上进行训练时,不会出现此警告。 The TF2.0 model does continue to train, however it is slowing the training process significantly as it keeps printing to the warning. TF2.0 model 确实会继续训练,但是它会显着减慢训练过程,因为它会一直打印到警告。

I have gone back to the dataset and tried converting the PNGs ensuring that they are in RGB format not RGBA using the example listed here however this has failed to resolve the problem.我已经返回数据集并尝试使用此处列出的示例转换 PNG,确保它们为 RGB 格式而不是 RGBA,但这未能解决问题。

After further research I found further details about the ImageDataGenerator and specifically the flow_from_dataframe method in the TensorFlow Core r2.0 documentation .经过进一步研究,我在TensorFlow Core r2.0 文档中找到了有关 ImageDataGenerator 的更多详细信息,特别是flow_from_dataframe方法。

The UserWarning described can be resolved by adding the color_mode='rgba' parameter to the flow_from_dataframe method.可以通过将color_mode='rgba'参数添加到flow_from_dataframe方法来解决所描述的 UserWarning。

For example:例如:

from tensorflow.python.keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    rescale=1. / 255,
    shear_range=0.1,
    zoom_range=0.2,
    horizontal_flip=False,
    fill_mode='nearest')

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(150, 150),
    batch_size=16,
    save_format='png',
    class_mode='sparse',
    color_mode='rgba')

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM