I have added data augmentation to a Keras model that I am training on the Fashion-MNIST dataset. Training runs fine, but as soon as the first epoch finishes it throws an error.
The error: ValueError: Shapes (32, 1) and (32, 10) are incompatible
My data:
img_rows = 28
img_cols = 28
batch_size = 512
img_shape = (img_rows, img_cols, 1)
x_train = x_train.reshape(x_train.shape[0], *img_shape)
x_test = x_test.reshape(x_test.shape[0], *img_shape)
x_val = x_val.reshape(x_val.shape[0], *img_shape)
label_as_binary = LabelBinarizer()
y_train_binary = label_as_binary.fit_transform(y_train)
y_test_binary = label_as_binary.fit_transform(y_test)
y_val_binary = label_as_binary.fit_transform(y_val)
My model:
model2 = Sequential([
    Conv2D(filters=32, kernel_size=3, activation='relu',
           input_shape=img_shape, padding="same"),
    MaxPooling2D(pool_size=2),
    Conv2D(filters=32, kernel_size=3, activation='relu',
           padding="same"),
    MaxPooling2D(pool_size=2),
    Dropout(0.25),
    Conv2D(filters=64, kernel_size=3, activation='relu',
           padding="same"),
    MaxPooling2D(pool_size=2),
    Conv2D(filters=64, kernel_size=3, activation='relu',
           padding="same"),
    MaxPooling2D(pool_size=2),
    Dropout(0.25),
    Flatten(),
    Dense(512, activation='relu'),
    Dense(10, activation='softmax')
])
The data augmentation:
datagen = ImageDataGenerator(horizontal_flip=True, rotation_range=45,
                             width_shift_range=0.2, height_shift_range=0.2,
                             zoom_range=0.1)
datagen.fit(x_train)
for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=9):
    for i in range(0, 9):
        pyplot.subplot(330 + 1 + i)
        pyplot.imshow(x_batch[i].reshape(28, 28),
                      cmap=pyplot.get_cmap('gray'))
    pyplot.show()
    break
model2.compile(loss='categorical_crossentropy',
               optimizer=Adadelta(learning_rate=0.01), metrics=['accuracy'])
history = model2.fit_generator(datagen.flow(x_train, y_train_binary,
                                            batch_size=batch_size),
                               epochs = 10, validation_data = (x_train, y_val_binary), verbose = 1)
I have seen many similar questions, but none of the answers seem to fit my case. Any help is much appreciated.
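Your validation_data pairs the training images (x_train) with the validation labels (y_val_binary). Keras only evaluates validation_data at the end of an epoch, which is why the error shows up after the first epoch rather than at the start. I think you should change this line:
validation_data = (x_train, y_val_binary)
to this:
validation_data = (x_val, y_val_binary)
Then the inputs and labels come from the same split, and your model should run properly.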
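If you want to catch this kind of mix-up before training starts, a quick shape check on each (inputs, labels) pair can help. The sketch below is only an illustration: `check_pair` is a hypothetical helper (not part of Keras), and the `*_demo` arrays are made-up stand-ins with arbitrary sizes rather than the real Fashion-MNIST splits. It assumes the labels are meant to be one-hot encoded with 10 columns, matching the final Dense(10, activation='softmax') layer.

```python
import numpy as np


def check_pair(x, y, num_classes=10):
    """Hypothetical helper: check that inputs and one-hot labels belong together."""
    # Inputs and labels must describe the same number of samples.
    if x.shape[0] != y.shape[0]:
        raise ValueError(
            f"sample count mismatch: {x.shape[0]} inputs vs {y.shape[0]} labels"
        )
    # categorical_crossentropy expects labels shaped (n_samples, num_classes).
    if y.ndim != 2 or y.shape[1] != num_classes:
        raise ValueError(
            f"labels have shape {y.shape}, expected (n, {num_classes})"
        )
    return True


# Stand-in arrays with made-up sizes, used only to illustrate the check.
x_train_demo = np.zeros((48, 28, 28, 1))
x_val_demo = np.zeros((12, 28, 28, 1))
y_val_demo = np.eye(10)[np.arange(12) % 10]  # one-hot labels, shape (12, 10)

check_pair(x_val_demo, y_val_demo)  # passes: both arrays come from the same split

try:
    # Mixes training images with validation labels, as in the question.
    check_pair(x_train_demo, y_val_demo)
except ValueError as err:
    print(err)
```

Running the same check on (x_train, y_train_binary) and (x_val, y_val_binary) before calling fit would also flag labels that came out with a single column instead of 10.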