簡體   English   中英

我未能訓練 CNN + LSTM model。 我怎么解決這個問題? 數據集有問題嗎? 還是 model? (Python 3.8x)

[英]I failed to train CNN + LSTM model. How can I solve this problem? Is it have problem in dataset? or model? (Python 3.8x)

0. 我用過:

  • Python 3.8x

  • JupyterLab >=3.0

  • Tensorflow

  • Keras

  • VGG19(預訓練模型)

1. 我的問題

我嘗試將 CNN + LSTM Python model 訓練到視頻分類(二進制分類)。

但是......我沒有訓練我的 model。 我的 JupyterLab(>=3.0) 只打印了Epoch 1/100並且幾乎停止了,或者重新啟動了內核(我建議可能 memory 不夠,但我的桌面有 16GB RAM。)。

我弄錯了 model 嗎? 還是我的數據集有問題?

另外,有時我減少了訓練數據的大小。(2000 -> 100)但問題沒有解決。

這是我的 model 和數據集的結構。

2.輸入數據形狀(我的數據集)

數據:data_training_ar

  • 類型:numpy 陣列
  • 形狀:(2697、30、160、160、3)

它有 2697 個視頻的 160*160 大小的 RGB ndarray。 每個視頻有 30 幀。

  • 示例:data_training_ar[10]
array([[[[0.03105 , 0.02397 , 0.02713 ],
         [0.08167 , 0.0738  , 0.0777  ],
         [0.1142  , 0.1064  , 0.1103  ],
         ...,
         [0.183   , 0.1752  , 0.1713  ],
         [0.12427 , 0.11646 , 0.1137  ],
         [0.01765 , 0.0098  , 0.00784 ]],

        [[0.1113  , 0.1051  , 0.1074  ],
         [0.5225  , 0.5146  , 0.5186  ],
         [0.3794  , 0.3713  , 0.3755  ],
         ...,
         [0.2229  , 0.2151  , 0.2112  ],
         [0.1255  , 0.1177  , 0.1137  ],
         [0.013725, 0.00816 , 0.005882]],

        [[0.124   , 0.11615 , 0.1201  ],
         [0.4556  , 0.4478  , 0.4517  ],
         [0.3982  , 0.3904  , 0.3943  ],
         ...,
         [0.1613  , 0.1534  , 0.1495  ],
         [0.1173  , 0.10956 , 0.1075  ],
         [0.0098  , 0.005882, 0.005882]],

        ...,

        [[0.08453 , 0.08246 , 0.08246 ],
         [0.4902  , 0.498   , 0.4863  ],
         [0.5337  , 0.5728  , 0.5337  ],
         ...,
         [0.9883  , 0.9883  , 0.9883  ],
         [0.9883  , 0.9883  , 0.9883  ],
         [0.9883  , 0.9883  , 0.9883  ]],

        [[0.08234 , 0.0807  , 0.08466 ],
         [0.482   , 0.4941  , 0.4883  ],
         [0.51    , 0.554   , 0.521   ],
         ...,
         [0.9766  , 0.9766  , 0.9766  ],
         [0.9663  , 0.9663  , 0.9663  ],
         [0.9683  , 0.9683  , 0.9683  ]],

        [[0.08234 , 0.0843  , 0.0863  ],
         [0.4824  , 0.496   , 0.4902  ],
         [0.51    , 0.551   , 0.5195  ],
         ...,
         [0.4133  , 0.4133  , 0.4133  ],
         [0.3955  , 0.3955  , 0.3955  ],
         [0.3523  , 0.3523  , 0.3523  ]]],


       [[[0.01689 , 0.01221 , 0.01296 ],
         [0.0955  , 0.08765 , 0.09155 ],
         [0.1139  , 0.1061  , 0.11    ],
         ...,
         [0.179   , 0.1711  , 0.1672  ],
         [0.12354 , 0.11566 , 0.11255 ],
         [0.01645 , 0.0098  , 0.0098  ]],

        [[0.11365 , 0.10583 , 0.10974 ],
         [0.5186  , 0.5107  , 0.5146  ],
         [0.3809  , 0.373   , 0.377   ],
         ...,
         [0.232   , 0.2242  , 0.2203  ],
         [0.1232  , 0.11566 , 0.11176 ],
         [0.013725, 0.0098  , 0.00784 ]],

        [[0.135   , 0.1274  , 0.1311  ],
         [0.4604  , 0.4526  , 0.4565  ],
         [0.3862  , 0.3784  , 0.3823  ],
         ...,
         [0.1727  , 0.1648  , 0.1609  ],
         [0.11115 , 0.10333 , 0.09937 ],
         [0.013725, 0.00784 , 0.005882]],

        ...,

        [[0.07855 , 0.0787  , 0.0745  ],
         [0.4788  , 0.4963  , 0.4785  ],
         [0.5337  , 0.563   , 0.5317  ],
         ...,
         [0.9883  , 0.9883  , 0.9883  ],
         [0.9883  , 0.9883  , 0.9883  ],
         [0.9883  , 0.9883  , 0.9883  ]],

        [[0.0745  , 0.0804  , 0.0784  ],
         [0.4727  , 0.496   , 0.4805  ],
         [0.5137  , 0.551   , 0.5254  ],
         ...,
         [0.9717  , 0.9717  , 0.9717  ],
         [0.974   , 0.974   , 0.974   ],
         [0.973   , 0.973   , 0.973   ]],

        [[0.0745  , 0.08234 , 0.0804  ],
         [0.4727  , 0.498   , 0.4844  ],
         [0.5137  , 0.551   , 0.5254  ],
         ...,
         [0.4067  , 0.4067  , 0.4067  ],
         [0.3923  , 0.3923  , 0.3923  ],
         [0.3586  , 0.3586  , 0.3586  ]]],


       [[[0.01689 , 0.01025 , 0.01296 ],
         [0.09265 , 0.07965 , 0.0836  ],
         [0.12445 , 0.1053  , 0.11    ],
         ...,
         [0.172   , 0.1674  , 0.1635  ],
         [0.111   , 0.1149  , 0.10706 ],
         [0.00784 , 0.008606, 0.00784 ]],

        [[0.1068  , 0.0996  , 0.1029  ],
         [0.522   , 0.5117  , 0.5156  ],
         [0.3933  , 0.3755  , 0.3813  ],
         ...,
         [0.2363  , 0.2305  , 0.2249  ],
         [0.1209  , 0.1213  , 0.1134  ],
         [0.00784 , 0.00948 , 0.00784 ]],

        [[0.1294  , 0.1239  , 0.1257  ],
         [0.4658  , 0.4563  , 0.46    ],
         [0.395   , 0.3796  , 0.3835  ],
         ...,
         [0.1705  , 0.1627  , 0.1588  ],
         [0.1207  , 0.11676 , 0.111   ],
         [0.00968 , 0.00968 , 0.005882]],

        ...,

        [[0.0726  , 0.0784  , 0.0727  ],
         [0.471   , 0.4963  , 0.4749  ],
         [0.528   , 0.565   , 0.5317  ],
         ...,
         [0.9883  , 0.9883  , 0.9883  ],
         [0.9883  , 0.9883  , 0.9883  ],
         [0.9883  , 0.9883  , 0.9883  ]],

        [[0.0745  , 0.0804  , 0.0784  ],
         [0.4727  , 0.496   , 0.4805  ],
         [0.517   , 0.5547  , 0.5293  ],
         ...,
         [0.977   , 0.977   , 0.977   ],
         [0.9707  , 0.9707  , 0.9707  ],
         [0.9766  , 0.9766  , 0.9766  ]],

        [[0.0745  , 0.08234 , 0.08234 ],
         [0.4746  , 0.498   , 0.4844  ],
         [0.5137  , 0.5527  , 0.5254  ],
         ...,
         [0.4087  , 0.4087  , 0.4087  ],
         [0.3977  , 0.3977  , 0.3977  ],
         [0.3484  , 0.3484  , 0.3484  ]]],


       ...,


       [[[0.01778 , 0.01778 , 0.01778 ],
         [0.08307 , 0.08307 , 0.08307 ],
         [0.1046  , 0.1046  , 0.1046  ],
         ...,
         [0.1659  , 0.1744  , 0.1631  ],
         [0.08594 , 0.0938  , 0.0899  ],
         [0.00784 , 0.011765, 0.011765]],

        [[0.0795  , 0.0795  , 0.0795  ],
         [0.4434  , 0.4434  , 0.4434  ],
         [0.3796  , 0.3796  , 0.3796  ],
         ...,
         [0.2612  , 0.2708  , 0.2573  ],
         [0.1079  , 0.1157  , 0.11017 ],
         [0.00784 , 0.011765, 0.011765]],

        [[0.0664  , 0.0664  , 0.0664  ],
         [0.3572  , 0.3572  , 0.3572  ],
         [0.388   , 0.388   , 0.388   ],
         ...,
         [0.1753  , 0.1792  , 0.1674  ],
         [0.1013  , 0.1054  , 0.10144 ],
         [0.00772 , 0.01164 , 0.01152 ]],

        ...,

        [[0.08234 , 0.0784  , 0.0844  ],
         [0.512   , 0.512   , 0.516   ],
         [0.563   , 0.563   , 0.557   ],
         ...,
         [0.9844  , 0.9844  , 0.9844  ],
         [0.9883  , 0.988   , 0.9883  ],
         [0.9883  , 0.988   , 0.9883  ]],

        [[0.0843  , 0.08264 , 0.0843  ],
         [0.506   , 0.506   , 0.506   ],
         [0.549   , 0.549   , 0.549   ],
         ...,
         [0.972   , 0.972   , 0.972   ],
         [0.9766  , 0.9766  , 0.9766  ],
         [0.9727  , 0.9727  , 0.9727  ]],

        [[0.0843  , 0.0843  , 0.0843  ],
         [0.506   , 0.506   , 0.506   ],
         [0.547   , 0.547   , 0.547   ],
         ...,
         [0.4138  , 0.4138  , 0.4138  ],
         [0.3975  , 0.3975  , 0.3975  ],
         [0.3496  , 0.3496  , 0.3496  ]]],


       [[[0.01581 , 0.01581 , 0.01581 ],
         [0.0835  , 0.0835  , 0.0835  ],
         [0.1042  , 0.1042  , 0.1042  ],
         ...,
         [0.1631  , 0.1725  , 0.1611  ],
         [0.08594 , 0.0938  , 0.0899  ],
         [0.00784 , 0.011765, 0.011765]],

        [[0.07623 , 0.07623 , 0.07623 ],
         [0.442   , 0.442   , 0.442   ],
         [0.3748  , 0.3748  , 0.3748  ],
         ...,
         [0.2605  , 0.269   , 0.257   ],
         [0.1082  , 0.116   , 0.1118  ],
         [0.00784 , 0.011765, 0.011765]],

        [[0.0646  , 0.0646  , 0.0646  ],
         [0.3538  , 0.3538  , 0.3538  ],
         [0.3918  , 0.3918  , 0.3918  ],
         ...,
         [0.1735  , 0.1792  , 0.1655  ],
         [0.1013  , 0.10724 , 0.10333 ],
         [0.00772 , 0.01164 , 0.01152 ]],

        ...,

        [[0.08234 , 0.0784  , 0.0844  ],
         [0.512   , 0.512   , 0.516   ],
         [0.563   , 0.563   , 0.557   ],
         ...,
         [0.9844  , 0.9844  , 0.9844  ],
         [0.9883  , 0.988   , 0.9883  ],
         [0.9883  , 0.988   , 0.9883  ]],

        [[0.0843  , 0.08264 , 0.0843  ],
         [0.506   , 0.506   , 0.506   ],
         [0.549   , 0.549   , 0.549   ],
         ...,
         [0.972   , 0.972   , 0.972   ],
         [0.9766  , 0.9766  , 0.9766  ],
         [0.9727  , 0.9727  , 0.9727  ]],

        [[0.0843  , 0.0843  , 0.0843  ],
         [0.506   , 0.506   , 0.506   ],
         [0.547   , 0.547   , 0.547   ],
         ...,
         [0.4138  , 0.4138  , 0.4138  ],
         [0.3975  , 0.3975  , 0.3975  ],
         [0.3496  , 0.3496  , 0.3496  ]]],


       [[[0.01581 , 0.01581 , 0.01581 ],
         [0.0835  , 0.0835  , 0.0835  ],
         [0.1042  , 0.1042  , 0.1042  ],
         ...,
         [0.1624  , 0.1709  , 0.1592  ],
         [0.08594 , 0.0938  , 0.0899  ],
         [0.00784 , 0.011765, 0.011765]],

        [[0.07623 , 0.07623 , 0.07623 ],
         [0.442   , 0.442   , 0.442   ],
         [0.3748  , 0.3748  , 0.3748  ],
         ...,
         [0.2646  , 0.2747  , 0.261   ],
         [0.1082  , 0.116   , 0.11017 ],
         [0.00784 , 0.011765, 0.011765]],

        [[0.0646  , 0.0646  , 0.0646  ],
         [0.3538  , 0.3538  , 0.3538  ],
         [0.3918  , 0.3918  , 0.3918  ],
         ...,
         [0.1755  , 0.1792  , 0.1674  ],
         [0.1013  , 0.1054  , 0.10144 ],
         [0.00772 , 0.01164 , 0.01152 ]],

        ...,

        [[0.08234 , 0.0784  , 0.0844  ],
         [0.512   , 0.512   , 0.516   ],
         [0.563   , 0.563   , 0.557   ],
         ...,
         [0.9844  , 0.9844  , 0.9844  ],
         [0.9883  , 0.988   , 0.9883  ],
         [0.9883  , 0.988   , 0.9883  ]],

        [[0.0843  , 0.08264 , 0.0843  ],
         [0.506   , 0.506   , 0.506   ],
         [0.549   , 0.549   , 0.549   ],
         ...,
         [0.972   , 0.972   , 0.972   ],
         [0.9766  , 0.9766  , 0.9766  ],
         [0.9727  , 0.9727  , 0.9727  ]],

        [[0.0843  , 0.0843  , 0.0843  ],
         [0.506   , 0.506   , 0.506   ],
         [0.547   , 0.547   , 0.547   ],
         ...,
         [0.4138  , 0.4138  , 0.4138  ],
         [0.3975  , 0.3975  , 0.3975  ],
         [0.3496  , 0.3496  , 0.3496  ]]]], dtype=float16)

目標:label_training_ar

  • 類型:numpy 陣列

  • 形狀:(2697、30、2)

  • 示例:label_training_ar[10]

array([[0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.]])

3. VGG19 + LSTM Model

3-1。 代碼

base_model=keras.applications.VGG19(include_top=False, input_shape=(160, 160, 3), weights='imagenet')

image_model=keras.models.Sequential()
image_model.add(base_model)
image_model.add(keras.layers.Flatten())
image_model.add(keras.layers.Dense(4096, activation='relu', name='fc1'))
image_model.add(keras.layers.Dense(4096, activation='relu', name='fc2'))
image_model.add(keras.layers.Dense(1000, activation='softmax', name='predictions'))

chunk_size=4096
n_chunks=30
rnn_size=512

model=keras.models.Sequential()
model.add(keras.layers.TimeDistributed(image_model, input_shape=(30, 160, 160, 3)))

model.add(keras.layers.LSTM(rnn_size, input_shape=(n_chunks, chunk_size))) # (30, 4096)
model.add(keras.layers.Dense(1024))
model.add(keras.layers.Activation('relu'))
model.add(keras.layers.Dense(256))
model.add(keras.layers.Activation('sigmoid'))
model.add(keras.layers.Dense(2))
model.add(keras.layers.Activation('softmax'))

model.compile(loss='mean_squared_error', optimizer='adam', metrics=['accuracy'])

3-2。 Plot 圖像

在此處輸入圖像描述

4. Model配件(模型培訓)

epoch=100
batchS=30
history=model.fit(x=data_training_ar[0:2000], y=label_training_ar[0:2000], epochs=epoch,
                  validation_data=(data_training_ar[2000:], label_training_ar[2000:]),
                  callbacks=[checkpoint_cb], #keras.callbacks.ModelCheckpoint('210429_vc_13-02_checkpoint.h5', save_best_only=True)
                  batch_size=batchS, verbose=2)

如果可能,嘗試使用 PyCharm 並查看錯誤是否仍然存在? 還要檢查它是否是相同的錯誤。

我在 Google Colab 中運行 VGG 系列模型。 它相當快。

嘗試使用 Spyder 或只是記事本並直接在命令行上運行您的腳本。 這是為了確保您的問題與運行 Jupyter 的 web 服務器的超時無關。 它還將允許您查看完整的堆棧跟蹤。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM