How to correctly fine-tune a model using Keras?

I am trying to fine-tune a VGG16 model to predict 12 classes. My classes are different from the ImageNet classes, which is why I am fine-tuning. During the first epoch the model produces the output below, with a very high loss and low accuracy. I suspect that by the end of the 10th epoch the accuracy still won't reach a statistically meaningful level. Is this normal, or am I doing something wrong?

PROGRESS OUTPUT

  1/362 [..............................] - ETA: 17:00 - loss: 4.2610 - acc: 0.0625
  2/362 [..............................] - ETA: 16:06 - loss: 381.0046 - acc: 0.0312
  3/362 [..............................] - ETA: 15:44 - loss: 298.8458 - acc: 0.0208
  4/362 [..............................] - ETA: 15:35 - loss: 226.6889 - acc: 0.0156
  5/362 [..............................] - ETA: 15:29 - loss: 182.2427 - acc: 0.1125
  6/362 [..............................] - ETA: 15:22 - loss: 160.5883 - acc: 0.0938
  7/362 [..............................] - ETA: 15:17 - loss: 138.1007 - acc: 0.1562
  8/362 [..............................] - ETA: 15:13 - loss: 121.4596 - acc: 0.1367
  9/362 [..............................] - ETA: 15:10 - loss: 108.2340 - acc: 0.1285
 10/362 [..............................] - ETA: 15:05 - loss: 97.6966 - acc: 0.1156 
 11/362 [..............................] - ETA: 15:02 - loss: 89.1747 - acc: 0.1080
 12/362 [..............................] - ETA: 14:58 - loss: 82.0693 - acc: 0.1016
 13/362 [>.............................] - ETA: 14:54 - loss: 75.9421 - acc: 0.0962
 14/362 [>.............................] - ETA: 14:50 - loss: 70.7751 - acc: 0.0915
 15/362 [>.............................] - ETA: 14:47 - loss: 66.2048 - acc: 0.0875
 16/362 [>.............................] - ETA: 14:42 - loss: 62.3336 - acc: 0.0820
 17/362 [>.............................] - ETA: 14:37 - loss: 58.7955 - acc: 0.0809
 18/362 [>.............................] - ETA: 14:32 - loss: 55.6291 - acc: 0.0990
 19/362 [>.............................] - ETA: 14:27 - loss: 53.4754 - acc: 0.0938
 20/362 [>.............................] - ETA: 14:22 - loss: 50.9502 - acc: 0.0922
 21/362 [>.............................] - ETA: 14:18 - loss: 51.2024 - acc: 0.0878
 22/362 [>.............................] - ETA: 14:14 - loss: 49.6734 - acc: 0.0838
 23/362 [>.............................] - ETA: 14:11 - loss: 47.6672 - acc: 0.0802
 24/362 [>.............................] - ETA: 14:07 - loss: 45.8416 - acc: 0.0807
 25/362 [=>............................] - ETA: 14:03 - loss: 44.1554 - acc: 0.0887
 26/362 [=>............................] - ETA: 14:01 - loss: 42.5939 - acc: 0.0853
 27/362 [=>............................] - ETA: 13:59 - loss: 41.5915 - acc: 0.0822
 28/362 [=>............................] - ETA: 13:55 - loss: 40.6199 - acc: 0.0904
 29/362 [=>............................] - ETA: 13:52 - loss: 39.3127 - acc: 0.0873
 30/362 [=>............................] - ETA: 13:49 - loss: 38.1391 - acc: 0.0844
 31/362 [=>............................] - ETA: 13:47 - loss: 37.0256 - acc: 0.0837
 32/362 [=>............................] - ETA: 13:44 - loss: 35.9840 - acc: 0.0820
 33/362 [=>............................] - ETA: 13:42 - loss: 34.9967 - acc: 0.0795
 34/362 [=>............................] - ETA: 13:39 - loss: 34.0358 - acc: 0.0846
 35/362 [=>............................] - ETA: 13:37 - loss: 33.1187 - acc: 0.1000
 36/362 [=>............................] - ETA: 13:34 - loss: 32.2891 - acc: 0.0998
 37/362 [==>...........................] - ETA: 13:32 - loss: 31.4585 - acc: 0.1115
 38/362 [==>...........................] - ETA: 13:29 - loss: 30.7113 - acc: 0.1086
 39/362 [==>...........................] - ETA: 13:27 - loss: 30.3106 - acc: 0.1074
 40/362 [==>...........................] - ETA: 13:24 - loss: 29.6147 - acc: 0.1047
 41/362 [==>...........................] - ETA: 13:22 - loss: 28.9527 - acc: 0.1052
 42/362 [==>...........................] - ETA: 13:20 - loss: 28.3151 - acc: 0.1042
 43/362 [==>...........................] - ETA: 13:18 - loss: 27.8628 - acc: 0.1017
 44/362 [==>...........................] - ETA: 13:16 - loss: 27.3071 - acc: 0.0994
 45/362 [==>...........................] - ETA: 13:14 - loss: 26.7912 - acc: 0.0972
 46/362 [==>...........................] - ETA: 13:12 - loss: 26.2834 - acc: 0.0992
 47/362 [==>...........................] - ETA: 13:09 - loss: 26.1134 - acc: 0.0971
 48/362 [==>...........................] - ETA: 13:07 - loss: 25.6429 - acc: 0.0951
 49/362 [===>..........................] - ETA: 13:04 - loss: 25.2244 - acc: 0.0931
 50/362 [===>..........................] - ETA: 13:02 - loss: 25.3814 - acc: 0.0969
 51/362 [===>..........................] - ETA: 12:59 - loss: 24.9342 - acc: 0.0956
 52/362 [===>..........................] - ETA: 12:56 - loss: 24.5366 - acc: 0.0938
 53/362 [===>..........................] - ETA: 12:54 - loss: 24.1119 - acc: 0.0991
 54/362 [===>..........................] - ETA: 12:51 - loss: 23.7437 - acc: 0.0972
 55/362 [===>..........................] - ETA: 12:48 - loss: 23.3581 - acc: 0.0960
 56/362 [===>..........................] - ETA: 12:46 - loss: 23.1098 - acc: 0.0943
 57/362 [===>..........................] - ETA: 12:43 - loss: 22.7660 - acc: 0.0927
 58/362 [===>..........................] - ETA: 12:41 - loss: 22.4216 - acc: 0.0911
 59/362 [===>..........................] - ETA: 12:38 - loss: 22.1043 - acc: 0.0895
 60/362 [===>..........................] - ETA: 12:36 - loss: 21.7761 - acc: 0.0938
 61/362 [====>.........................] - ETA: 12:33 - loss: 21.4751 - acc: 0.0922
 62/362 [====>.........................] - ETA: 12:30 - loss: 21.2513 - acc: 0.0907
 63/362 [====>.........................] - ETA: 12:27 - loss: 20.9492 - acc: 0.0918
 64/362 [====>.........................] - ETA: 12:24 - loss: 20.6726 - acc: 0.0903
 65/362 [====>.........................] - ETA: 12:21 - loss: 20.3861 - acc: 0.0923
 66/362 [====>.........................] - ETA: 12:18 - loss: 20.1427 - acc: 0.0909
 67/362 [====>.........................] - ETA: 12:15 - loss: 19.8937 - acc: 0.0905
 68/362 [====>.........................] - ETA: 12:12 - loss: 19.6332 - acc: 0.0901
 69/362 [====>.........................] - ETA: 12:09 - loss: 19.4407 - acc: 0.0888
 70/362 [====>.........................] - ETA: 12:06 - loss: 19.2126 - acc: 0.0875
 71/362 [====>.........................] - ETA: 12:03 - loss: 18.9823 - acc: 0.0893
 72/362 [====>.........................] - ETA: 12:01 - loss: 18.7506 - acc: 0.0885
 73/362 [=====>........................] - ETA: 11:58 - loss: 18.5105 - acc: 0.0967
 74/362 [=====>........................] - ETA: 11:55 - loss: 18.2603 - acc: 0.1090
 75/362 [=====>........................] - ETA: 11:52 - loss: 18.0168 - acc: 0.1208
 76/362 [=====>........................] - ETA: 11:49 - loss: 18.4406 - acc: 0.1192
 77/362 [=====>........................] - ETA: 11:46 - loss: 18.2507 - acc: 0.1185
 78/362 [=====>........................] - ETA: 11:43 - loss: 18.0736 - acc: 0.1170
 79/362 [=====>........................] - ETA: 11:41 - loss: 17.8874 - acc: 0.1163
 80/362 [=====>........................] - ETA: 11:38 - loss: 17.7183 - acc: 0.1148
 81/362 [=====>........................] - ETA: 11:35 - loss: 17.5281 - acc: 0.1154
 82/362 [=====>........................] - ETA: 11:32 - loss: 17.3660 - acc: 0.1139
 83/362 [=====>........................] - ETA: 11:29 - loss: 17.1998 - acc: 0.1126
 84/362 [=====>........................] - ETA: 11:26 - loss: 17.0346 - acc: 0.1135
 85/362 [======>.......................] - ETA: 11:23 - loss: 17.0110 - acc: 0.1132
 86/362 [======>.......................] - ETA: 11:20 - loss: 16.9285 - acc: 0.1119
 87/362 [======>.......................] - ETA: 11:18 - loss: 16.7656 - acc: 0.1171
 88/362 [======>.......................] - ETA: 11:15 - loss: 16.8805 - acc: 0.1158
 89/362 [======>.......................] - ETA: 11:13 - loss: 16.7179 - acc: 0.1145
 90/362 [======>.......................] - ETA: 11:10 - loss: 16.6223 - acc: 0.1135
 91/362 [======>.......................] - ETA: 11:08 - loss: 16.4856 - acc: 0.1123
 92/362 [======>.......................] - ETA: 11:06 - loss: 16.3405 - acc: 0.1131
 93/362 [======>.......................] - ETA: 11:03 - loss: 16.2040 - acc: 0.1119
 94/362 [======>.......................] - ETA: 11:00 - loss: 16.0733 - acc: 0.1107
 95/362 [======>.......................] - ETA: 10:58 - loss: 15.9401 - acc: 0.1109
 96/362 [======>.......................] - ETA: 10:55 - loss: 15.8051 - acc: 0.1136
 97/362 [=======>......................] - ETA: 10:53 - loss: 15.6705 - acc: 0.1169
 98/362 [=======>......................] - ETA: 10:50 - loss: 15.5265 - acc: 0.1234
 99/362 [=======>......................] - ETA: 10:48 - loss: 15.4323 - acc: 0.1222
100/362 [=======>......................] - ETA: 10:46 - loss: 15.3164 - acc: 0.1209
101/362 [=======>......................] - ETA: 10:43 - loss: 15.2053 - acc: 0.1197
102/362 [=======>......................] - ETA: 10:41 - loss: 15.0870 - acc: 0.1192
103/362 [=======>......................] - ETA: 10:38 - loss: 14.9731 - acc: 0.1186
104/362 [=======>......................] - ETA: 10:36 - loss: 14.8524 - acc: 0.1175
105/362 [=======>......................] - ETA: 10:34 - loss: 14.7380 - acc: 0.1208
106/362 [=======>......................] - ETA: 10:31 - loss: 14.6252 - acc: 0.1197
107/362 [=======>......................] - ETA: 10:29 - loss: 14.5199 - acc: 0.1192
108/362 [=======>......................] - ETA: 10:26 - loss: 14.4181 - acc: 0.1183
109/362 [========>.....................] - ETA: 10:24 - loss: 14.3249 - acc: 0.1173
110/362 [========>.....................] - ETA: 10:21 - loss: 14.2265 - acc: 0.1162
111/362 [========>.....................] - ETA: 10:18 - loss: 14.1234 - acc: 0.1151
112/362 [========>.....................] - ETA: 10:16 - loss: 14.0295 - acc: 0.1150
113/362 [========>.....................] - ETA: 10:14 - loss: 13.9287 - acc: 0.1139
114/362 [========>.....................] - ETA: 10:11 - loss: 13.8335 - acc: 0.1132
115/362 [========>.....................] - ETA: 10:09 - loss: 13.7322 - acc: 0.1174
116/362 [========>.....................] - ETA: 10:06 - loss: 13.6475 - acc: 0.1164
117/362 [========>.....................] - ETA: 10:04 - loss: 13.6022 - acc: 0.1154
118/362 [========>.....................] - ETA: 10:01 - loss: 13.5110 - acc: 0.1144
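To put the log above in context: with 12 classes, a model that predicts the uniform distribution for every image has an expected accuracy of 1/12 (about 0.083) and a categorical cross-entropy of ln(12) (about 2.48). The accuracy in the log stays mostly between about 0.08 and 0.12, i.e. close to chance, while the displayed loss is a running average over the epoch and stays far above ln(12) after the spike to 381.0 at step 2. A minimal sketch of that arithmetic in plain Python (chance_baselines is a hypothetical helper name):

```python
import math


def chance_baselines(num_classes: int) -> tuple[float, float]:
    """Return (accuracy, cross-entropy) of a classifier that predicts the
    uniform distribution 1/num_classes for every sample."""
    accuracy = 1.0 / num_classes
    # Cross-entropy of the true class under a uniform prediction: -ln(1/K) = ln(K)
    cross_entropy = -math.log(1.0 / num_classes)
    return accuracy, cross_entropy


acc, loss = chance_baselines(12)
print(f"chance accuracy: {acc:.4f}")          # chance accuracy: 0.0833
print(f"uniform-prediction loss: {loss:.4f}")  # uniform-prediction loss: 2.4849
```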

MODEL

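import math
from collections import Counter

from tensorflow import keras as k
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator

image_size = 150

train_batchsize = 32
val_batchsize = 32
class_weights = {}

# train_data_dir, validation_data_dir and get_class_weights() are defined
# elsewhere in the question's code and are not shown here.


def create_model():
    # VGG16 convolutional base with ImageNet weights, without the 1000-class ImageNet head
    vgg_conv = VGG16(include_top=False, weights='imagenet', input_shape=(image_size, image_size, 3))

    # Freeze the layers except the last 4 layers (block5_conv1..3 and block5_pool)
    for layer in vgg_conv.layers[:-4]:
        layer.trainable = False

    model = Sequential()
    model.add(vgg_conv)
    # model.add(Flatten())
    model.add(GlobalAveragePooling2D())
    model.add(Dense(256, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(12, activation='softmax'))  # one output per class
    return model


def train_top_model():
    # STEP 1 : GENERATE MODEL

    model = create_model()
    # Show a summary of the model. Check the number of trainable parameters
    model.summary()

    # Scales pixel values from [0, 255] to [0, 1]
    data_gen = ImageDataGenerator(
        rescale=1. / 255,
    )

    # shuffle=False: images are yielded in a fixed order, grouped by class subdirectory
    train_generator = data_gen.flow_from_directory(
        train_data_dir,
        batch_size=train_batchsize,
        target_size=(image_size, image_size),
        class_mode='categorical',
        shuffle=False)

    validation_generator = data_gen.flow_from_directory(
        validation_data_dir,
        batch_size=val_batchsize,
        target_size=(image_size, image_size),
        class_mode='categorical',
        shuffle=False)

    class_weights = get_class_weights(train_generator)

    # Compile the model
    model.compile(loss='categorical_crossentropy',
                  optimizer=k.optimizers.RMSprop(learning_rate=1e-4),
                  metrics=['acc'])

    # Train the model (model.fit accepts generators; fit_generator is deprecated).
    # Step counts must be integers, so round up to include the last partial batch.
    history = model.fit(
        train_generator,
        steps_per_epoch=math.ceil(train_generator.samples / train_generator.batch_size),
        epochs=10,
        validation_data=validation_generator,
        validation_steps=math.ceil(validation_generator.samples / validation_generator.batch_size),
        class_weight=class_weights,
        verbose=1)
    return history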
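The question does not show get_class_weights(). Given the Counter import, one plausible version (hypothetical, not the asker's code) computes "balanced" weights, n_samples / (n_classes * count_c), which is the same formula scikit-learn uses for class_weight='balanced'. The sketch below takes a plain sequence of integer labels so it runs without Keras; balanced_class_weights is a hypothetical name.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Hypothetical stand-in for the question's get_class_weights():
    weight_c = n_samples / (n_classes * count_c), so rarer classes get larger weights.
    `labels` is a sequence of integer class indices."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)  # only classes that actually occur in `labels`
    return {cls: n_samples / (n_classes * count) for cls, count in sorted(counts.items())}


# Toy labels: class 0 appears 6 times, class 1 twice, class 2 four times
weights = balanced_class_weights([0] * 6 + [1] * 2 + [2] * 4)
print(weights)  # {0: 0.6666666666666666, 1: 2.0, 2: 1.0}
```

With the question's generator it would be called as class_weights = balanced_class_weights(train_generator.classes), since a flow_from_directory iterator exposes the integer label of every image in its classes attribute.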

Quick answer: you should replace rescale=1. / 255 with preprocessing_function=lambda x: x - np.array([103.939, 116.779, 123.68]). The input normalization in your code is incorrect. For more information, look at the mode used by the Keras VGG16 preprocess_input function.
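The means in the quick answer are the ImageNet per-channel means that Keras's VGG16 preprocess_input uses in its default "caffe" mode. That mode first reverses the channel order from RGB to BGR and only then subtracts the means, so 103.939 is the blue-channel mean; the lambda above subtracts the same array from the RGB channels without reversing them. The numpy sketch below (caffe_preprocess and rescale_preprocess are hypothetical helper names) shows both transforms on one H x W x 3 RGB image with values in [0, 255]. In Keras itself, passing preprocessing_function=preprocess_input (from tensorflow.keras.applications.vgg16) to ImageDataGenerator applies the library's own implementation instead of a hand-written one.

```python
import numpy as np

# ImageNet per-channel means in BGR order, as used by VGG16 "caffe"-mode preprocessing
IMAGENET_BGR_MEAN = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def caffe_preprocess(rgb_image: np.ndarray) -> np.ndarray:
    """Sketch of VGG16-style preprocessing for one HxWx3 RGB image in [0, 255]:
    reverse RGB -> BGR, then subtract the ImageNet channel means (no scaling)."""
    bgr = rgb_image[..., ::-1].astype(np.float32)  # copy with channels reversed
    return bgr - IMAGENET_BGR_MEAN


def rescale_preprocess(rgb_image: np.ndarray) -> np.ndarray:
    """What rescale=1. / 255 does: map [0, 255] to [0, 1], keeping RGB order."""
    return rgb_image.astype(np.float32) / 255.0


# A 2x2 dummy RGB image: one pure-red pixel, the rest mid-grey
img = np.full((2, 2, 3), 128, dtype=np.uint8)
img[0, 0] = [255, 0, 0]

print(rescale_preprocess(img)[0, 0])  # [1. 0. 0.]
print(caffe_preprocess(img)[0, 0])    # BGR order: [0 - 103.939, 0 - 116.779, 255 - 123.68]
```

The mean-subtracted values stay on the [0, 255] scale the pretrained VGG16 filters saw during ImageNet training, whereas rescale=1. / 255 compresses every channel into [0, 1].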


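Here is a checklist of what to look for during fine-tuning / transfer learning. The highlighted items are the ones causing your problem:

  • Make sure the input scale is correct.
  • Manually inspect the dataset the model was trained on, and check whether it is similar to the dataset you are trying to learn.
  • Make sure to freeze the bottom layers. Large gradients from the upper layers can destroy the features learned by the bottom layers.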
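One common way to apply the last item is to fine-tune in two stages: first freeze the entire VGG16 base and train only the new head, then unfreeze the last few base layers and recompile with a smaller learning rate. In Keras, changes to layer.trainable only take effect after model.compile() is called again. The sketch below uses a hypothetical helper, freeze_all_but_last, that works on any list of objects with a trainable attribute (such as vgg_conv.layers); it also sidesteps a slicing pitfall, because layers[:-0] is an empty slice, so a loop over layers[:-n] with n = 0 would freeze nothing. The 19 stand-in layers match the layer count of VGG16 with include_top=False (input layer included), and the stage-2 learning rate is an arbitrary illustrative value.

```python
from types import SimpleNamespace


def freeze_all_but_last(layers, n_trainable):
    """Set trainable=False on every layer except the last n_trainable ones.

    n_trainable=0 freezes everything, unlike `for layer in layers[:-0]`,
    which iterates over an empty slice."""
    cutoff = len(layers) - n_trainable
    for i, layer in enumerate(layers):
        layer.trainable = i >= cutoff


# Hypothetical two-stage schedule with Keras (names from the question):
#   stage 1: freeze_all_but_last(vgg_conv.layers, 0); model.compile(...); model.fit(...)
#   stage 2: freeze_all_but_last(vgg_conv.layers, 4)
#            model.compile(optimizer=RMSprop(learning_rate=1e-5), ...)  # recompile!
#            model.fit(...)

# Stand-in layers so the helper can be exercised without TensorFlow
layers = [SimpleNamespace(name=f"layer{i}", trainable=True) for i in range(19)]
freeze_all_but_last(layers, 0)
print(sum(layer.trainable for layer in layers))  # 0
freeze_all_but_last(layers, 4)
print([layer.name for layer in layers if layer.trainable])  # ['layer15', 'layer16', 'layer17', 'layer18']
```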


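Notice: technical posts on this site are licensed under CC BY-SA 4.0. If you republish one, please credit this site's URL or the original address. For any questions, contact yoyou2525@163.com.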
