
Output of VGG Model becomes constant after training and Loss/Accuracy are not improving

I am trying to implement a slightly smaller version of VGG16 and train it from scratch on a dataset of about 6000 images (5400 for training, 600 for validation). I chose a batch size of 30 so that it divides the dataset evenly; otherwise I would get an IncompatibleShape error during training.
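As a quick sanity check, a batch size of 30 does divide both splits exactly, which matches the 180 steps per epoch in the training log below:

```python
# Batch-size arithmetic using the numbers from the question.
train_images, val_images, batch_size = 5400, 600, 30

# Both splits divide evenly, so there is no ragged final batch to
# trigger a shape mismatch against a fixed-batch input layer.
assert train_images % batch_size == 0 and val_images % batch_size == 0
print(train_images // batch_size)  # 180 training steps per epoch
print(val_images // batch_size)    # 20 validation steps
```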

After 15-20 epochs, the EarlyStopping callback kicks in and stops training.

I am facing two problems with this model:

  1. Afterwards, when I pass test images into the model, the output appears to remain constant. At the very least, I would expect the predicted output for imageA to differ from that for imageB. I cannot figure out why this is happening.
  2. The loss and accuracy do not seem to change much. I expected the accuracy to reach at least around 50% given the number of epochs, but it does not go above 23%. I tried adding steps_per_epoch and ReduceLROnPlateau, but they do not seem to have any effect.
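The question does not show the callback configuration, but from the log below it presumably looks something like the following sketch (patience values and the LR-reduction factor are assumptions; the monitored metric, checkpoint filename, and min_lr are taken from the log):

```python
# Hypothetical Keras callback setup matching the training log;
# exact patience/factor values are guesses, not the author's code.
from tensorflow.keras.callbacks import (
    EarlyStopping, ModelCheckpoint, ReduceLROnPlateau)

callbacks = [
    # "Restoring model weights from the end of the best epoch."
    EarlyStopping(monitor="val_loss", patience=10,
                  restore_best_weights=True, verbose=1),
    # LR bottoms out at 1e-05 in the log (epoch 7).
    ReduceLROnPlateau(monitor="val_loss", factor=0.25,
                      patience=3, min_lr=1e-5, verbose=1),
    # "saving model to vgg16.h5" on each val_loss improvement.
    ModelCheckpoint("vgg16.h5", monitor="val_loss",
                    save_best_only=True, verbose=1),
]
```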

Training output:

Epoch 1/50
180/180 [==============================] - 50s 278ms/step - loss: 1.6095 - categorical_accuracy: 0.1987 - val_loss: 1.6109 - val_categorical_accuracy: 0.1267

Epoch 00001: val_loss improved from inf to 1.61094, saving model to vgg16.h5
Epoch 2/50
180/180 [==============================] - 51s 285ms/step - loss: 1.6095 - categorical_accuracy: 0.2044 - val_loss: 1.6107 - val_categorical_accuracy: 0.2133

Epoch 00002: val_loss improved from 1.61094 to 1.61067, saving model to vgg16.h5
Epoch 3/50
180/180 [==============================] - 51s 285ms/step - loss: 1.6098 - categorical_accuracy: 0.1946 - val_loss: 1.6106 - val_categorical_accuracy: 0.1400

Epoch 00003: val_loss improved from 1.61067 to 1.61059, saving model to vgg16.h5
Epoch 4/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6095 - categorical_accuracy: 0.1928 - val_loss: 1.6098 - val_categorical_accuracy: 0.2000

Epoch 00004: val_loss improved from 1.61059 to 1.60983, saving model to vgg16.h5

Epoch 00004: ReduceLROnPlateau reducing learning rate to 2.5000001187436283e-05.
Epoch 5/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6093 - categorical_accuracy: 0.2033 - val_loss: 1.6103 - val_categorical_accuracy: 0.1467

Epoch 00005: val_loss did not improve from 1.60983
Epoch 6/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6094 - categorical_accuracy: 0.1989 - val_loss: 1.6106 - val_categorical_accuracy: 0.1400

Epoch 00006: val_loss did not improve from 1.60983
Epoch 7/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6094 - categorical_accuracy: 0.2069 - val_loss: 1.6098 - val_categorical_accuracy: 0.1733

Epoch 00007: val_loss improved from 1.60983 to 1.60978, saving model to vgg16.h5

Epoch 00007: ReduceLROnPlateau reducing learning rate to 1e-05.
Epoch 8/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6093 - categorical_accuracy: 0.2076 - val_loss: 1.6103 - val_categorical_accuracy: 0.1600

Epoch 00008: val_loss did not improve from 1.60978
Epoch 9/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6095 - categorical_accuracy: 0.2006 - val_loss: 1.6097 - val_categorical_accuracy: 0.2200

Epoch 00009: val_loss improved from 1.60978 to 1.60975, saving model to vgg16.h5
Epoch 10/50
180/180 [==============================] - 52s 287ms/step - loss: 1.6095 - categorical_accuracy: 0.2043 - val_loss: 1.6101 - val_categorical_accuracy: 0.1667

Epoch 00010: val_loss did not improve from 1.60975
Epoch 11/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6094 - categorical_accuracy: 0.2009 - val_loss: 1.6102 - val_categorical_accuracy: 0.1800

Epoch 00011: val_loss did not improve from 1.60975
Epoch 12/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6095 - categorical_accuracy: 0.2041 - val_loss: 1.6115 - val_categorical_accuracy: 0.1600

Epoch 00012: val_loss did not improve from 1.60975
Epoch 13/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6095 - categorical_accuracy: 0.1989 - val_loss: 1.6108 - val_categorical_accuracy: 0.1867

Epoch 00013: val_loss did not improve from 1.60975
Epoch 14/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6094 - categorical_accuracy: 0.2009 - val_loss: 1.6102 - val_categorical_accuracy: 0.1733

Epoch 00014: val_loss did not improve from 1.60975
Epoch 15/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6093 - categorical_accuracy: 0.2074 - val_loss: 1.6113 - val_categorical_accuracy: 0.1467

Epoch 00015: val_loss did not improve from 1.60975
Epoch 16/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6098 - categorical_accuracy: 0.1983 - val_loss: 1.6105 - val_categorical_accuracy: 0.1867

Epoch 00016: val_loss did not improve from 1.60975
Epoch 17/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6095 - categorical_accuracy: 0.2056 - val_loss: 1.6119 - val_categorical_accuracy: 0.1667

Epoch 00017: val_loss did not improve from 1.60975
Epoch 18/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6093 - categorical_accuracy: 0.1994 - val_loss: 1.6110 - val_categorical_accuracy: 0.1800

Epoch 00018: val_loss did not improve from 1.60975
Epoch 19/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6095 - categorical_accuracy: 0.2026 - val_loss: 1.6103 - val_categorical_accuracy: 0.1667

Epoch 00019: val_loss did not improve from 1.60975
Restoring model weights from the end of the best epoch.
Epoch 00019: early stopping

Code used to get the predictions:

predictions = []
actuals = []

# Run the first three test batches through the trained model
for i, (images, labels) in enumerate(test_datasource):
    if i > 2:
        break
    pred = model_2(images)
    print(labels.shape, pred.shape)
    for j in range(len(labels)):
        actuals.append(labels[j])
        predictions.append(pred[j])
        print(labels[j].numpy(), "\t", pred[j].numpy())

Output of the above code:

(30, 5) (30, 5)
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
(30, 5) (30, 5)
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
(30, 5) (30, 5)
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]

Model summary:

Model: "vgg16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_3 (InputLayer)         [(30, 224, 224, 3)]       0         
_________________________________________________________________
conv_1_1 (Conv2D)            (30, 224, 224, 32)        896       
_________________________________________________________________
conv_1_2 (Conv2D)            (30, 224, 224, 32)        9248      
_________________________________________________________________
maxpool_1 (MaxPooling2D)     (30, 112, 112, 32)        0         
_________________________________________________________________
conv_2_1 (Conv2D)            (30, 112, 112, 64)        18496     
_________________________________________________________________
conv_2_2 (Conv2D)            (30, 112, 112, 64)        36928     
_________________________________________________________________
maxpool_2 (MaxPooling2D)     (30, 56, 56, 64)          0         
_________________________________________________________________
conv_3_1 (Conv2D)            (30, 56, 56, 128)         73856     
_________________________________________________________________
conv_3_2 (Conv2D)            (30, 56, 56, 128)         147584    
_________________________________________________________________
conv_3_3 (Conv2D)            (30, 56, 56, 128)         147584    
_________________________________________________________________
maxpool_3 (MaxPooling2D)     (30, 28, 28, 128)         0         
_________________________________________________________________
conv_4_1 (Conv2D)            (30, 28, 28, 256)         295168    
_________________________________________________________________
conv_4_2 (Conv2D)            (30, 28, 28, 256)         590080    
_________________________________________________________________
conv_4_3 (Conv2D)            (30, 28, 28, 256)         590080    
_________________________________________________________________
maxpool_4 (MaxPooling2D)     (30, 14, 14, 256)         0         
_________________________________________________________________
conv_5_1 (Conv2D)            (30, 14, 14, 256)         590080    
_________________________________________________________________
conv_5_2 (Conv2D)            (30, 14, 14, 256)         590080    
_________________________________________________________________
conv_5_3 (Conv2D)            (30, 14, 14, 256)         590080    
_________________________________________________________________
maxpool_5 (MaxPooling2D)     (30, 7, 7, 256)           0         
_________________________________________________________________
flatten (Flatten)            (30, 12544)               0         
_________________________________________________________________
fc_1 (Dense)                 (30, 4096)                51384320  
_________________________________________________________________
dropout_1 (Dropout)          (30, 4096)                0         
_________________________________________________________________
fc_2 (Dense)                 (30, 4096)                16781312  
_________________________________________________________________
dropout_2 (Dropout)          (30, 4096)                0         
_________________________________________________________________
output (Dense)               (30, 5)                   20485     
=================================================================
Total params: 71,866,277
Trainable params: 71,866,277
Non-trainable params: 0

The code is in Google Colab: https://colab.research.google.com/drive/1AWe87Zb3MvF90j3RS7sv3OiSgR86q4j_

I tried two versions of VGG-16: one with half the original filter depth, and a second with a quarter of the filter depth.

My impression is that the root problem is the constant loss and accuracy, which is why, when the model is used for prediction, the output stays fixed no matter what input is passed in.

When I re-initialize the model by calling the Model(inputs..., outputs...) function and pass inputs to it without training, the output does at least vary.

I tried multiple learning rates and optimizers; the model's behavior did not change.

After more googling, I stumbled upon these articles:

  1. https://www.quora.com/Why-does-my-convolutional-neural-network-always-produce-the-same-outputs
  2. https://www.quora.com/Why-does-my-own-neural-network-give-me-the-same-output-for-different-input-sets
  3. https://datascience.stackexchange.com/questions/5706/what-is-the-dying-relu-problem-in-neural-networks

I made two changes to the code to finally get it working.

  1. Initially I just divided the image arrays by 255 to get values between 0 and 1, then subtracted 0.5 from the result to bring them between -0.5 and 0.5. I changed this to use tf.image.per_image_standardization(images - 127) and divided the result by the maximum value in each image. As a result, the image values lie between -1 and +1.
  2. The other major cause of the fixed output was the model's ReLU units dying (or saturating) during training. The ReLU activation function inherently has this problem: once a unit's output is driven to 0, it never recovers. Although a high learning rate is said to cause this problem, I could not find a learning rate that mitigated it. Another solution is to change the activation function to leaky ReLU or ELU (exponential linear unit), which have a built-in mechanism for recovering from this problem.
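The two fixes can be illustrated numerically without TensorFlow. The sketch below reimplements the preprocessing from point 1 in NumPy (tf.image.per_image_standardization subtracts the per-image mean and divides by the standard deviation; the function name and toy image here are illustrative), then shows why leaky ReLU can recover where plain ReLU cannot:

```python
import numpy as np

def standardize_image(img):
    """NumPy sketch of the preprocessing in point 1: shift by 127,
    standardize per image (subtract mean, divide by std), then
    divide by the per-image maximum so values land in [-1, 1]."""
    x = img.astype(np.float32) - 127.0
    x = (x - x.mean()) / max(float(x.std()), 1e-6)
    return x / np.abs(x).max()

def relu(x):
    return np.maximum(0.0, x)

def leaky_relu(x, alpha=0.1):
    # Negative inputs keep a small slope instead of flattening to
    # zero, so gradients still flow and a "dead" unit can recover.
    return np.where(x >= 0, x, alpha * x)

img = np.arange(48).reshape(4, 4, 3)   # toy stand-in for an image
out = standardize_image(img)
assert out.min() >= -1.0 and out.max() <= 1.0

z = np.array([-2.0, -0.5, 0.0, 1.5])   # example pre-activations
print(relu(z))        # negatives flattened to 0 -> zero gradient
print(leaky_relu(z))  # negatives scaled by alpha -> nonzero gradient
```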

With these changes, the model's loss dropped below 1 and the training accuracy climbed to around 0.90. The validation numbers are not as good, but they are definitely better than before.
