繁体   English   中英

高精度训练但低精度测试/预测

[英]High accuracy training but low accuracy test/prediction

我正在使用 CNN 对苹果类型进行分类。 我在火车数据上取得了很高的准确性,但在测试数据上的准确性却很低。 数据分成 80:20。 我不确定我的数据是否过度拟合。

我有 2 个包含TraningDataTestData的文件夹,每个文件夹有4个子文件夹braeburn, red_apple, red_delicious, rotten (包含相应的图片)。

TRAIN_DIR = 'apple_fruit'
TEST_DIR = 'apple_fruit'
classes = ['braeburn','red_apples','red_delicious','rotten'] train_datagen = ImageDataGenerator(rescale = 1./255, shear_range=0.2, zoom_range=0.2, horizontal_flip=True, fill_mode='nearest')
 
test_datagen = ImageDataGenerator(rescale = 1./255) 

training_set = train_datagen.flow_from_directory(TRAIN_DIR,
shuffle=True,
target_size = (100,100),
batch_size = 25,
classes =['braeburn','red_apples','red_delicious','rotten'])

test_set= test_datagen.flow_from_directory(TEST_DIR,
target_size = (100, 100),
shuffle=True,
 batch_size = 25,classes = classes)

model =Sequential()
model.add(Conv2D(filters=128, kernel_size=(3,3),input_shape=(100,100,3), activation='relu', padding
= 'same'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(filters=16, kernel_size=(3,3), activation='relu', padding = 'same'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(256))
model.add(Activation('relu'))
model.add(Dropout(0.6))
model.add(Dense(4,activation='softmax'))
model.compile(optimizer ='adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])

history = model.fit(x=training_set,#y=training_set.labels,
steps_per_epoch=len(training_set),
epochs =10)

model.save('Ripe2_model6.h5')  # creates a HDF5 file 'my_model.h5'

model_path = "Ripe2_model6.h5"
loaded_model = keras.models.load_model(model_path)
classes = ['braeburn','red_apples','red_delicious','rotten']
predictions = model.predict(x=test_set, steps=len(test_set), verbose=True)
pred = np.round(predictions)

y_true=test_set.classes
y_pred=np.argmax(pred, axis=-1)
    > cm = confusion_matrix(y_true=test_set.classes, y_pred=np.argmax(pred, axis=-1))
test_set.classes
np.argmax(pred, axis=-1)
def plot_confusion_matrix(cm, classes,
normalize=False,
title='Confusion matrix',
cmap=plt.cm.Blues):

accuracy = np.trace(cm) / float(np.sum(cm))
misclass = 1 - accuracy

  """
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
    """
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title,color = 'white')
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45,color = 'white')
plt.yticks(tick_marks, classes,color = 'white')
target_names = ['braeburn','red_apples','red_delicious','rotten']

if target_names is not None:
 tick_marks = np.arange(len(target_names))
 plt.xticks(tick_marks, target_names, rotation=45)
 plt.yticks(tick_marks, target_names)

if normalize:
 cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
 thresh = cm.max() / 1.5 if normalize else cm.max() / 2
 for i, j in itertools.product(range(cm.shape[0]), 
  range(cm.shape[1])):
  if normalize:
   plt.text(j, i, "{:0.4f}".format(cm[i, j]),
   horizontalalignment="center",
   color="white" if cm[i, j] > thresh else "black")
  else:
   plt.text(j, i, "{:,}".format(cm[i, j]),
   horizontalalignment="center",
   color="white" if cm[i, j] > thresh else "black")

plt.tight_layout()
plt.ylabel('True label',color = 'white')
plt.xlabel('Predicted label',color = 'white')

cm_plot_labels = ['braeburn','red_apples','red_delicious','rotten']
plot_confusion_matrix(cm=cm, classes=cm_plot_labels, title='Confusion Matrix')

print(accuracy_score(y_true, y_pred))
print(recall_score(y_true, y_pred, average=None))
print(precision_score(y_true, y_pred, average=None))

混淆矩阵:

  • 精度 - 0.2909090909090909
  • 召回 - [0.23484848 0.32319392 0.15151515 0.36213992]
  • 精度 - [0.23308271 0.32319392 0.15151515 0.36363636]

我尝试更改许多功能,但仍然没有进展。

这表明测试集中的数据与模型学习到的数据有很大不同。 要了解它是否过度拟合或一次不幸的分裂:

  1. 检查您的结果是否依赖于初始的 Train/Test split 为此,您可以:
  • [可选] 将所有图片合并到整个数据集(train+test)文件夹中。
  • 将图像随机分割成训练/测试(而不是使用初始分割)
  • 实施交叉验证(例如 K-Fold)
  1. 你有足够数量的样本吗? 尝试添加更多示例并检查它如何影响性能。 您还可以应用数据增强技术。

如果训练数据的准确率很好,但测试数据的准确率很低,那么模型就容易过拟合。 原因可能是一个简单的数据集,其中模型试图捕获包括噪声在内的所有数据点。 在上述情况下,尝试调整参数并设置更高的批次,执行交叉验证以了解性能并执行数据扩充。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM