I am trying to fine-tune the InceptionV3 model on my custom dataset (2 classes), but I get very low accuracy on both training and validation. What should I do to increase the accuracy? Or can you suggest other networks/implementations for this purpose?
My code:
from keras.layers import Input, Dense, GlobalAveragePooling2D
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.applications.inception_v3 import InceptionV3
epochs = 10
steps_per_epoch = 300
validation_steps = 300
input_shape=(64, 64, 3)
image_rows=64
image_cols=64
train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
    'dataset/train',
    target_size=(image_rows, image_cols),
    batch_size=32,
    class_mode='categorical')
validation_generator = test_datagen.flow_from_directory(
    'dataset/evaluate',
    target_size=(image_rows, image_cols),
    batch_size=32,
    class_mode='categorical')
inputs = Input(shape=input_shape)
base_model = InceptionV3(weights='imagenet', include_top=False)
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(2, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)
# Freeze the pretrained layers so only the new head is trained
for layer in base_model.layers:
    layer.trainable = False
model.compile(
    optimizer='rmsprop',
    loss='categorical_crossentropy',
    metrics=['accuracy'])
model.fit_generator(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=validation_steps)
Your problem is that, according to the Keras InceptionV3 documentation, the minimum input size is 139. Since your network's input size is 64, the network doesn't work well. To overcome this issue:
- change the input size to n, where n is at least 139,
- in each flow_from_directory call, change the target_size to (n, n).
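The two changes above could be sketched roughly as follows, using the same Keras 2 style API as the question. The value N = 150 is only an assumed example (any size that reaches the documented minimum should do), and the helper names build_generators and build_model are hypothetical, introduced here for illustration. Keras is imported inside the functions so the size settings can be checked without loading the library.

```python
# Assumed value for illustration; the answer only requires n to reach
# the minimum of 139 quoted from the Keras documentation.
N = 150
MIN_SIZE = 139

INPUT_SHAPE = (N, N, 3)   # passed to InceptionV3
TARGET_SIZE = (N, N)      # passed to every flow_from_directory call


def build_generators(train_dir='dataset/train',
                     val_dir='dataset/evaluate',
                     batch_size=32):
    """Create train/validation generators that resize images to (N, N)."""
    from keras.preprocessing.image import ImageDataGenerator

    train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)
    test_datagen = ImageDataGenerator(rescale=1./255)

    train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=TARGET_SIZE,
        batch_size=batch_size,
        class_mode='categorical')
    validation_generator = test_datagen.flow_from_directory(
        val_dir,
        target_size=TARGET_SIZE,
        batch_size=batch_size,
        class_mode='categorical')
    return train_generator, validation_generator


def build_model(num_classes=2):
    """Frozen InceptionV3 base with a small trainable classification head."""
    from keras.applications.inception_v3 import InceptionV3
    from keras.layers import Dense, GlobalAveragePooling2D
    from keras.models import Model

    # Give the base model the same spatial size the generators produce.
    base_model = InceptionV3(weights='imagenet',
                             include_top=False,
                             input_shape=INPUT_SHAPE)
    for layer in base_model.layers:
        layer.trainable = False

    x = GlobalAveragePooling2D()(base_model.output)
    x = Dense(1024, activation='relu')(x)
    predictions = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs=base_model.input, outputs=predictions)
    model.compile(optimizer='rmsprop',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model
```

Calling build_generators() and build_model() and passing the results to fit_generator as in the question would then train on N x N images. Keeping the size in one constant makes it harder for the model input and the generators to drift apart.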