繁体   English   中英

如何使用 tensorflow 数据集训练神经网络?

[英]How do I train a neural network with tensorflow-datasets?

我正在尝试在 emnist 数据集上训练神经网络,但是当我尝试展平图像时,它会引发以下错误:

警告:tensorflow:Model 是用形状 (None, 28, 28) 构造的,用于输入张量 ("flatten_input:0", shape=(None, 28, 28), dtype=float32) 与输入不兼容形状(无、1、28、28)。

我无法弄清楚似乎是什么问题,并尝试更改我的预处理,从我的 model.fit 和我的 ds.map 中删除批量大小。

这是完整的代码:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

def preprocess(dict):
    image = dict['image']
    image = tf.transpose(image)
    label = dict['label']
    return image, label

train_data, validation_data = tfds.load('emnist/letters', split = ['train', 'test'])
train_data_gen = train_data.map(preprocess).shuffle(1000).batch(32)
validation_data_gen = validation_data.map(preprocess).batch(32)

print(train_data_gen)
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape = (28, 28)),
    tf.keras.layers.Dense(128, activation = 'relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation = 'softmax')
])

model.compile(optimizer = 'adam',
    loss = 'sparse_categorical_crossentropy',
    metrics = ['accuracy'])

early_stopping = keras.callbacks.EarlyStopping(monitor = 'val_accuracy', patience = 10)
history = model.fit(train_data_gen, epochs = 50, batch_size = 32, validation_data = validation_data_gen, callbacks = [early_stopping], verbose = 1)
model.save('emnistmodel.h5')

所以这里实际上发生了一些事情,所以让我们一次解决它们。

  1. 输入形状

    因此,要解决您的直接问题,您会收到不兼容的形状错误,因为输入的形状与预期的形状不匹配。

    在这行tf.keras.layers.Flatten(input_shape=(28, 28)),我们告诉 model 期望输入形状为 (28, 28),但这并不准确。 (as opposed to a colour image which would have 3 channels r, g, and b).我们的输入实际上具有形状 (28, 28, 1),因为我们正在拍摄具有的 28x28 像素图像(而不是具有 3 个通道 r、g 和 b 的彩色图像)。 所以为了解决这个直接的问题,我们只需更新 model 以使用输入的形状。 tf.keras.layers.Flatten(input_shape=(28, 28, 1)),

  2. output节点数

    正如 Rishabh 在他的回答中所建议的那样, EMNIST 数据集有超过 10 个平衡类。 但是,在您的情况下,您似乎正在使用具有 26 个平衡类的 EMNIST Letters。 So your neural net should correspondingly have 27 output nodes (since the class labels go from 1.. 26 while our output nodes correspond to 0.. 26) to be able to classify the given data. 当然,给它额外的 output 节点也可以让它运行,但是这些会给我们额外的训练权重,这将增加我们的 model 所需的训练时间。 总之,你的最后一层应该是tf.keras.layers.Dense(27, activation='softmax')

  3. 预处理 TensorFlow 数据集

    阅读您的 preprocess() function,我相信您正在尝试将训练和验证数据集转换为(图像,标签)的元组。 TensorFlow 不是创建我们自己的 function,而是通过参数as_supervised方便地为我们实现这一点。

    此外,我看到您尝试实现的一些额外预处理,例如对数据进行批处理和洗牌。 再一次,TensorFlow 为我们实现了batch_sizeshuffle_files (见常用参数)! 所以加载数据集看起来像

    train_data, validation_data = tfds.load('emnist/letters', split=['train', 'test'], shuffle_files=True, batch_size=32, as_supervised=True)
  4. 一些附加说明

    此外,作为建议,请考虑从 model.fit() 中排除batch_size 在两个不同的地方定义相同的东西会导致错误和意外行为。 此外,当使用 TensorFlow 数据集时,没有必要,因为它们 已经生成了批次

总体而言,您更新的程序应如下所示

import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
from tensorflow import keras
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


train_data, validation_data = tfds.load('emnist/letters',
                                        split=['train', 'test'],
                                        shuffle_files=True,
                                        batch_size=32,
                                        as_supervised=True)

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(27, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_accuracy', patience=10)

history = model.fit(train_data,
                    epochs=50,
                    validation_data=validation_data,
                    callbacks=[early_stopping],
                    verbose=1)
model.save('emnistmodel.h5')

希望这可以帮助!

嗨 @Rattandeep 我刚刚检查了 emnist 数据集它有 47 个不同的类,在你的密集层中,你提到了 10 个。

如果您将代码从

tf.keras.layers.Dense(10,激活='softmax')

对于这个,它会起作用

tf.keras.layers.Dense(47,激活='softmax')

谢谢

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM