Output of VGG Model becomes constant after training and Loss/Accuracy are not improving
I'm trying to implement a slightly smaller version of VGG16 and train it from scratch on a dataset of about 6000 images (5400 for training and 600 for validation). I chose a batch size of 30 so that it divides the dataset evenly; otherwise I would get hit with an IncompatibleShape error during training.
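(As an aside, one way to avoid tying the batch size to the dataset size is to drop the final partial batch. This is a minimal sketch assuming a tf.data pipeline; the directory path and loader here are illustrative, not the exact input code from the notebook.)

import tensorflow as tf

# Assumed sketch: build the training pipeline so the last, smaller batch is
# dropped instead of causing a shape mismatch when the model expects a fixed
# batch dimension.
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "data/train",            # hypothetical path
    image_size=(224, 224),
    batch_size=30,
    label_mode="categorical",
)
# Drop any final batch with fewer than 30 samples.
train_ds = train_ds.unbatch().batch(30, drop_remainder=True)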
After going through 15-20 epochs, the EarlyStopping callback kicks in and stops the training.
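(For reference, the callback setup is roughly along these lines; this is a sketch consistent with the training log below, and the patience/factor values are assumptions rather than the exact ones used.)

from tensorflow.keras.callbacks import (EarlyStopping, ModelCheckpoint,
                                        ReduceLROnPlateau)

# Sketch of the callback configuration implied by the log below.
callbacks = [
    ModelCheckpoint("vgg16.h5", monitor="val_loss",
                    save_best_only=True, verbose=1),
    ReduceLROnPlateau(monitor="val_loss", factor=0.25,
                      patience=3, min_lr=1e-5, verbose=1),
    EarlyStopping(monitor="val_loss", patience=10,
                  restore_best_weights=True, verbose=1),
]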
I'm facing two issues with this model: the loss and accuracy do not improve during training, and after training the model produces the same output for every input.
Training output:
Epoch 1/50
180/180 [==============================] - 50s 278ms/step - loss: 1.6095 - categorical_accuracy: 0.1987 - val_loss: 1.6109 - val_categorical_accuracy: 0.1267
Epoch 00001: val_loss improved from inf to 1.61094, saving model to vgg16.h5
Epoch 2/50
180/180 [==============================] - 51s 285ms/step - loss: 1.6095 - categorical_accuracy: 0.2044 - val_loss: 1.6107 - val_categorical_accuracy: 0.2133
Epoch 00002: val_loss improved from 1.61094 to 1.61067, saving model to vgg16.h5
Epoch 3/50
180/180 [==============================] - 51s 285ms/step - loss: 1.6098 - categorical_accuracy: 0.1946 - val_loss: 1.6106 - val_categorical_accuracy: 0.1400
Epoch 00003: val_loss improved from 1.61067 to 1.61059, saving model to vgg16.h5
Epoch 4/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6095 - categorical_accuracy: 0.1928 - val_loss: 1.6098 - val_categorical_accuracy: 0.2000
Epoch 00004: val_loss improved from 1.61059 to 1.60983, saving model to vgg16.h5
Epoch 00004: ReduceLROnPlateau reducing learning rate to 2.5000001187436283e-05.
Epoch 5/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6093 - categorical_accuracy: 0.2033 - val_loss: 1.6103 - val_categorical_accuracy: 0.1467
Epoch 00005: val_loss did not improve from 1.60983
Epoch 6/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6094 - categorical_accuracy: 0.1989 - val_loss: 1.6106 - val_categorical_accuracy: 0.1400
Epoch 00006: val_loss did not improve from 1.60983
Epoch 7/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6094 - categorical_accuracy: 0.2069 - val_loss: 1.6098 - val_categorical_accuracy: 0.1733
Epoch 00007: val_loss improved from 1.60983 to 1.60978, saving model to vgg16.h5
Epoch 00007: ReduceLROnPlateau reducing learning rate to 1e-05.
Epoch 8/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6093 - categorical_accuracy: 0.2076 - val_loss: 1.6103 - val_categorical_accuracy: 0.1600
Epoch 00008: val_loss did not improve from 1.60978
Epoch 9/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6095 - categorical_accuracy: 0.2006 - val_loss: 1.6097 - val_categorical_accuracy: 0.2200
Epoch 00009: val_loss improved from 1.60978 to 1.60975, saving model to vgg16.h5
Epoch 10/50
180/180 [==============================] - 52s 287ms/step - loss: 1.6095 - categorical_accuracy: 0.2043 - val_loss: 1.6101 - val_categorical_accuracy: 0.1667
Epoch 00010: val_loss did not improve from 1.60975
Epoch 11/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6094 - categorical_accuracy: 0.2009 - val_loss: 1.6102 - val_categorical_accuracy: 0.1800
Epoch 00011: val_loss did not improve from 1.60975
Epoch 12/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6095 - categorical_accuracy: 0.2041 - val_loss: 1.6115 - val_categorical_accuracy: 0.1600
Epoch 00012: val_loss did not improve from 1.60975
Epoch 13/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6095 - categorical_accuracy: 0.1989 - val_loss: 1.6108 - val_categorical_accuracy: 0.1867
Epoch 00013: val_loss did not improve from 1.60975
Epoch 14/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6094 - categorical_accuracy: 0.2009 - val_loss: 1.6102 - val_categorical_accuracy: 0.1733
Epoch 00014: val_loss did not improve from 1.60975
Epoch 15/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6093 - categorical_accuracy: 0.2074 - val_loss: 1.6113 - val_categorical_accuracy: 0.1467
Epoch 00015: val_loss did not improve from 1.60975
Epoch 16/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6098 - categorical_accuracy: 0.1983 - val_loss: 1.6105 - val_categorical_accuracy: 0.1867
Epoch 00016: val_loss did not improve from 1.60975
Epoch 17/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6095 - categorical_accuracy: 0.2056 - val_loss: 1.6119 - val_categorical_accuracy: 0.1667
Epoch 00017: val_loss did not improve from 1.60975
Epoch 18/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6093 - categorical_accuracy: 0.1994 - val_loss: 1.6110 - val_categorical_accuracy: 0.1800
Epoch 00018: val_loss did not improve from 1.60975
Epoch 19/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6095 - categorical_accuracy: 0.2026 - val_loss: 1.6103 - val_categorical_accuracy: 0.1667
Epoch 00019: val_loss did not improve from 1.60975
Restoring model weights from the end of the best epoch.
Epoch 00019: early stopping
Code used to get predictions:
# Collect labels and predictions for the first three test batches
predictions = []
actuals = []
for i, (images, labels) in enumerate(test_datasource):
    if i > 2:
        break
    pred = model_2(images)  # forward pass on one batch
    print(labels.shape, pred.shape)
    for j in range(len(labels)):
        actuals.append(labels[j])
        predictions.append(pred[j])
        print(labels[j].numpy(), "\t", pred[j].numpy())
Output of the above code:
(30, 5) (30, 5)
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[1. 0. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[1. 0. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 0. 1.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 0. 1.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[1. 0. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 0. 1.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 0. 1.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[1. 0. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 0. 1.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 1. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 1. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[1. 0. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 1. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[1. 0. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 1. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 1. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
(30, 5) (30, 5)
[0. 1. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 0. 1.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 0. 1.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 1. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 0. 1.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[1. 0. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 1. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 0. 1.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[1. 0. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 0. 1.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 1. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 1. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 1. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 1. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 0. 1.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
(30, 5) (30, 5)
[0. 1. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 0. 1.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 0. 1.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 1. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 1. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[1. 0. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 0. 1.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[1. 0. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 0. 1.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[1. 0. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[1. 0. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[1. 0. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[1. 0. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[1. 0. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[1. 0. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[1. 0. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 1. 0. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 0. 1. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
[0. 0. 1. 0. 0.] [0.19907779 0.20320047 0.1968051 0.20173152 0.19918515]
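Every prediction row above is the same near-uniform vector (roughly 0.2 per class, i.e. chance level for 5 classes). A quick way to confirm the collapse numerically, assuming the actuals and predictions lists collected above:

import numpy as np

# Convert the collected one-hot labels and softmax outputs to class indices
# and measure accuracy; with a collapsed model this stays around 0.2.
y_true = np.argmax(np.stack([a.numpy() for a in actuals]), axis=1)
y_pred = np.argmax(np.stack([p.numpy() for p in predictions]), axis=1)
print("accuracy:", np.mean(y_true == y_pred))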
Here is the model summary:
Model: "vgg16"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_3 (InputLayer) [(30, 224, 224, 3)] 0
_________________________________________________________________
conv_1_1 (Conv2D) (30, 224, 224, 32) 896
_________________________________________________________________
conv_1_2 (Conv2D) (30, 224, 224, 32) 9248
_________________________________________________________________
maxpool_1 (MaxPooling2D) (30, 112, 112, 32) 0
_________________________________________________________________
conv_2_1 (Conv2D) (30, 112, 112, 64) 18496
_________________________________________________________________
conv_2_2 (Conv2D) (30, 112, 112, 64) 36928
_________________________________________________________________
maxpool_2 (MaxPooling2D) (30, 56, 56, 64) 0
_________________________________________________________________
conv_3_1 (Conv2D) (30, 56, 56, 128) 73856
_________________________________________________________________
conv_3_2 (Conv2D) (30, 56, 56, 128) 147584
_________________________________________________________________
conv_3_3 (Conv2D) (30, 56, 56, 128) 147584
_________________________________________________________________
maxpool_3 (MaxPooling2D) (30, 28, 28, 128) 0
_________________________________________________________________
conv_4_1 (Conv2D) (30, 28, 28, 256) 295168
_________________________________________________________________
conv_4_2 (Conv2D) (30, 28, 28, 256) 590080
_________________________________________________________________
conv_4_3 (Conv2D) (30, 28, 28, 256) 590080
_________________________________________________________________
maxpool_4 (MaxPooling2D) (30, 14, 14, 256) 0
_________________________________________________________________
conv_5_1 (Conv2D) (30, 14, 14, 256) 590080
_________________________________________________________________
conv_5_2 (Conv2D) (30, 14, 14, 256) 590080
_________________________________________________________________
conv_5_3 (Conv2D) (30, 14, 14, 256) 590080
_________________________________________________________________
maxpool_5 (MaxPooling2D) (30, 7, 7, 256) 0
_________________________________________________________________
flatten (Flatten) (30, 12544) 0
_________________________________________________________________
fc_1 (Dense) (30, 4096) 51384320
_________________________________________________________________
dropout_1 (Dropout) (30, 4096) 0
_________________________________________________________________
fc_2 (Dense) (30, 4096) 16781312
_________________________________________________________________
dropout_2 (Dropout) (30, 4096) 0
_________________________________________________________________
output (Dense) (30, 5) 20485
=================================================================
Total params: 71,866,277
Trainable params: 71,866,277
Non-trainable params: 0
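(A minimal sketch that reproduces this architecture with the Keras functional API is below; the 3x3 kernels, same-padding, ReLU activations, and dropout rate are assumptions based on standard VGG blocks, and the full notebook is linked after it.)

from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D

def conv_block(x, filters, n_convs, block_id):
    # n_convs 3x3 convolutions followed by a 2x2 max-pool, VGG-style
    for i in range(n_convs):
        x = Conv2D(filters, (3, 3), padding="same", activation="relu",
                   name="conv_{}_{}".format(block_id, i + 1))(x)
    return MaxPooling2D((2, 2), name="maxpool_{}".format(block_id))(x)

inputs = Input(shape=(224, 224, 3), batch_size=30)
x = conv_block(inputs, 32, 2, 1)
x = conv_block(x, 64, 2, 2)
x = conv_block(x, 128, 3, 3)
x = conv_block(x, 256, 3, 4)
x = conv_block(x, 256, 3, 5)
x = Flatten(name="flatten")(x)
x = Dense(4096, activation="relu", name="fc_1")(x)
x = Dropout(0.5, name="dropout_1")(x)
x = Dense(4096, activation="relu", name="fc_2")(x)
x = Dropout(0.5, name="dropout_2")(x)
outputs = Dense(5, activation="softmax", name="output")(x)
model_2 = Model(inputs, outputs, name="vgg16")
model_2.summary()  # matches the parameter counts shown above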
The code is here in Google Colab: https://colab.research.google.com/drive/1AWe87Zb3MvF90j3RS7sv3OiSgR86q4j_
I tried two versions of VGG-16: one with half the filter depth of the original and one with a quarter of the filter depth.
I was under the impression that the issue was that the loss and accuracy were unchanging, and that this was why, when using the model for predictions, the outputs were fixed and unchanging irrespective of the inputs fed to it.
When I reinitialized the model by calling the Model(inputs=..., outputs=...) function and passed inputs to it without training, the outputs at least changed from input to input.
I tried multiple learning rates and optimizers with no change in the model's behavior.
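(A sketch of the kind of variation this refers to; the specific optimizer and learning rate shown here are illustrative, not the exact values tried.)

from tensorflow.keras.optimizers import Adam

# Illustrative only: swapping the optimizer / learning rate at compile time,
# e.g. Adam(1e-4) vs. SGD(1e-3, momentum=0.9).
model_2.compile(optimizer=Adam(learning_rate=1e-4),
                loss="categorical_crossentropy",
                metrics=["categorical_accuracy"])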
After some more Google searching I chanced upon these articles:
I made two changes to the code to get it to finally work.
With these changes, the model's loss dropped below 1 and training accuracy approached 0.90; the validation numbers were not as good, but certainly better than before.