Reshaping MNIST for ResNet50

I am trying to train ResNet50 on the MNIST dataset using the Keras library. MNIST images have shape (28, 28, 1), but ResNet50 requires the shape to be (32, 32, 3).

How can I convert the MNIST dataset to the required shape?

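from keras import models, optimizers
from keras.applications import ResNet50
from keras.datasets import mnist
from keras.layers import BatchNormalization, Dense, Dropout, Flatten
from keras.utils import to_categorical

# conv_base is not shown in the question; assumed here to be a ResNet50 base
# built for 32x32 RGB input, which matches the expected shape in the error below
conv_base = ResNet50(weights='imagenet', include_top=False, input_shape=(32, 32, 3))

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# add a channel axis: (N, 28, 28) -> (N, 28, 28, 1)
x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], x_train.shape[2], 1)
x_test = x_test.reshape(x_test.shape[0], x_test.shape[1], x_test.shape[2], 1)
x_train = x_train/255.0
x_test = x_test/255.0
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
model = models.Sequential()
# model.add(InputLayer(input_shape=(28, 28)))
# model.add(Reshape(target_shape=(32, 32, 3)))
# model.add(Conv2D())
model.add(conv_base)
model.add(Flatten())
model.add(BatchNormalization())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(BatchNormalization())
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
model.add(BatchNormalization())
model.add(Dense(10, activation='softmax'))

# categorical_crossentropy (not binary_crossentropy) fits 10 one-hot classes;
# learning_rate replaces the deprecated lr argument
model.compile(optimizer=optimizers.RMSprop(learning_rate=2e-5), loss='categorical_crossentropy', metrics=['acc'])

history = model.fit(x_train, y_train, epochs=5, batch_size=20, validation_data=(x_test, y_test))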
ValueError: Input 0 is incompatible with layer sequential_10: expected shape=(None, 32, 32, 3), found shape=(20, 28, 28, 1) 
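The question's commented-out attempt, Reshape(target_shape=(32, 32, 3)), cannot work even in principle. A Keras Reshape layer, like NumPy's reshape, only rearranges existing values, so the total element count must stay the same. A (28, 28, 1) image holds 784 values, while (32, 32, 3) needs 3072. The NumPy sketch below illustrates this. The helper name can_reshape is hypothetical.

```python
import numpy as np


def can_reshape(src_shape, dst_shape):
    """Reshape only rearranges existing values, so element counts must match."""
    return int(np.prod(src_shape)) == int(np.prod(dst_shape))


# (28, 28) -> (28, 28, 1): only adds a channel axis, 784 == 784
print(can_reshape((28, 28), (28, 28, 1)))     # True
# (28, 28, 1) -> (32, 32, 3): 784 != 3072, so no reshape can do this
print(can_reshape((28, 28, 1), (32, 32, 3)))  # False

image = np.zeros((28, 28, 1), dtype=np.uint8)  # one MNIST-sized image
try:
    image.reshape(32, 32, 3)
except ValueError as err:
    print("ValueError:", err)
```

Getting from 784 to 3072 values requires creating new values. Resizing (interpolation) or padding creates new pixels, and copying the channel creates new channels.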

You need to resize the MNIST dataset. Note that the minimum input size depends on the ImageNet model: for example, Xception requires at least 71×71 pixels, whereas ResNet50 requires at least 32×32. In addition, MNIST images are grayscale (one channel), while the pretrained ImageNet weights of these models expect three-channel RGB input, so a one-channel input conflicts with them. The safe approach is therefore to both resize the images and convert them from grayscale to RGB.
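The reasoning above reduces to a small rule. First, upsample to the backbone's minimum side length if 28 is too small. Second, always use 3 channels for ImageNet weights. A sketch of that rule follows. The minimum sizes are the values given in the Keras Applications documentation for these two models, and the names MIN_SIDE and target_shape are hypothetical.

```python
# Minimum input side length (pixels) for two Keras ImageNet backbones,
# taken from the Keras Applications documentation (check your version).
MIN_SIDE = {"ResNet50": 32, "Xception": 71}


def target_shape(model_name, side=28):
    """(H, W, C) shape that square grayscale images must be converted to."""
    new_side = max(side, MIN_SIDE[model_name])  # upsample only if too small
    return (new_side, new_side, 3)              # 3 channels for RGB ImageNet weights


print(target_shape("ResNet50"))  # (32, 32, 3)
print(target_shape("Xception"))  # (71, 71, 3)
```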


Here is full working code.

Data set

We will resize the MNIST images from 28×28 to 32×32 and give them 3 channels instead of 1.
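Resizing with tf.image.resize changes pixel values by interpolation. A different technique, not the one this answer uses, is to zero-pad each 28×28 image with a 2-pixel border. That also reaches 32×32 (28 + 2 + 2) and leaves the original pixels untouched in the centre. The NumPy-only sketch below assumes the raw (N, 28, 28) uint8 arrays returned by mnist.load_data(). The name pad_to_rgb32 is hypothetical.

```python
import numpy as np


def pad_to_rgb32(images):
    """Zero-pad (N, 28, 28) grayscale images to (N, 32, 32, 3) float32."""
    x = images.astype("float32") / 255.0           # scale pixels to [0, 1]
    x = np.pad(x, ((0, 0), (2, 2), (2, 2)))        # black border: 28 + 2 + 2 = 32
    x = np.repeat(x[..., np.newaxis], 3, axis=-1)  # 1 channel -> 3 identical channels
    return x


batch = np.random.randint(0, 256, size=(4, 28, 28), dtype=np.uint8)
print(pad_to_rgb32(batch).shape)  # (4, 32, 32, 3)
```

Which approach trains better on MNIST is not something this answer measured. Padding simply avoids interpolation artifacts.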

import tensorflow as tf
import numpy as np

(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()

# add a channel axis: (60000, 28, 28) -> (60000, 28, 28, 1)
x_train = np.expand_dims(x_train, axis=-1)

# replicate the grayscale channel 3 times: (60000, 28, 28, 3);
# needed when using the 3-channel ImageNet weights
x_train = np.repeat(x_train, 3, axis=-1)

# scale pixel values to [0, 1]
x_train = x_train.astype('float32') / 255

# resize 28x28 -> 32x32, the minimum input size for ResNet50
x_train = tf.image.resize(x_train, [32, 32])

# one-hot encode the labels
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)

print(x_train.shape, y_train.shape)
(60000, 32, 32, 3) (60000, 10)
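The code above prepares only x_train. The test split needs exactly the same preprocessing before it can be passed to evaluate or used as validation_data. One way to guarantee that is to wrap the steps in a function. A minimal sketch follows, with to_resnet_input as a hypothetical name:

```python
import numpy as np
import tensorflow as tf


def to_resnet_input(images, size=32):
    """Apply this answer's preprocessing to a (N, 28, 28) uint8 MNIST split."""
    x = np.expand_dims(images, axis=-1)      # (N, 28, 28) -> (N, 28, 28, 1)
    x = np.repeat(x, 3, axis=-1)             # (N, 28, 28, 1) -> (N, 28, 28, 3)
    x = x.astype("float32") / 255            # scale pixels to [0, 1]
    return tf.image.resize(x, [size, size])  # bilinear resize -> (N, size, size, 3)
```

Usage: load both splits with tf.keras.datasets.mnist.load_data(), call to_resnet_input on x_train and on x_test, and one-hot encode y_test the same way as y_train. Resizing the whole training split in one call produces a 60000×32×32×3 float32 tensor, about 737 MB. Calling the function per batch (for example inside a tf.data pipeline) keeps memory lower.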

ResNet50

inputs = tf.keras.Input(shape=(32, 32, 3))
base = tf.keras.applications.ResNet50(weights='imagenet',
                                      include_top=False,
                                      input_tensor=inputs)
# Apply global max pooling to turn the feature maps into a vector.
pooled = tf.keras.layers.GlobalMaxPooling2D()(base.output)

# Finally, add a 10-way softmax classification layer.
output = tf.keras.layers.Dense(10, activation='softmax', use_bias=True)(pooled)

# Bind the input and output into one functional model.
func_model = tf.keras.Model(base.input, output)
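It helps to know how small the backbone's output is for a 32×32 input. The sketch below traces the spatial side length through ResNet50. It assumes the layer layout of the Keras implementation:

- 3-pixel zero padding, then a 7×7 stride-2 convolution;
- 1-pixel padding, then a 3×3 stride-2 max pool;
- stages conv3 to conv5, each starting with a 1×1 stride-2 convolution (conv2 keeps stride 1).

The function names are hypothetical.

```python
def conv_out(size, kernel, stride, pad=0):
    """Output length of a 'valid' conv/pool after symmetric zero padding."""
    return (size + 2 * pad - kernel) // stride + 1


def resnet50_feature_side(side):
    """Spatial side of ResNet50's final feature map (include_top=False)."""
    side = conv_out(side, kernel=7, stride=2, pad=3)  # conv1
    side = conv_out(side, kernel=3, stride=2, pad=1)  # pool1
    for _ in range(3):                                # conv3, conv4, conv5 stages
        side = conv_out(side, kernel=1, stride=2)
    return side


print(resnet50_feature_side(32))   # 1 -> feature map (None, 1, 1, 2048)
print(resnet50_feature_side(224))  # 7 -> feature map (None, 7, 7, 2048)
```

Under this assumption, a 32×32 input leaves a 1×1×2048 feature map. GlobalMaxPooling2D then takes the max over a single position, so it effectively just flattens that map into a 2048-vector for the Dense layer.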

Train

func_model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],  # metrics is expected to be a list
    optimizer=tf.keras.optimizers.Adam())
# fit
func_model.fit(x_train, y_train, batch_size=128, epochs=5, verbose=2)
Epoch 1/5
469/469 - 56s - loss: 0.1184 - categorical_accuracy: 0.9690
Epoch 2/5
469/469 - 21s - loss: 0.0648 - categorical_accuracy: 0.9844
Epoch 3/5
469/469 - 21s - loss: 0.0503 - categorical_accuracy: 0.9867
Epoch 4/5
469/469 - 21s - loss: 0.0416 - categorical_accuracy: 0.9888
Epoch 5/5
469/469 - 21s - loss: 0.1556 - categorical_accuracy: 0.9697
<tensorflow.python.keras.callbacks.History at 0x7f316005a3d0>
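The 469 in the log above is the number of batches per epoch. With batch_size=128, Keras splits the 60,000 training images into ceil(60000 / 128) batches, and the last batch is only partly filled. A small sketch of that arithmetic follows, with steps_per_epoch as a hypothetical helper name:

```python
import math


def steps_per_epoch(num_samples, batch_size):
    """Number of batches per epoch; the last batch may be partial."""
    return math.ceil(num_samples / batch_size)


print(steps_per_epoch(60000, 128))  # 469, matching "469/469" in the log
print(60000 - 468 * 128)            # 96 images in the final, partial batch
print(steps_per_epoch(60000, 20))   # 3000 with the question's batch_size=20
```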
