
Input dimension for LSTM

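I have a model like this.

It predicts the next item from the last 5 items.

(That is, given item[0] through item[4], it predicts item[5].)

My training data has shape (94, 11026).

So I pass (93, 5, 11026) as the input x and (93, 11026) as the target y. (Then I split these into training and validation sets.)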
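A minimal NumPy sketch of the windowing described above, assuming the rows of the training array are consecutive items in time order. The helper name make_windows is hypothetical, and the data is a small stand-in for the real (94, 11026) array. It yields num_items - window samples, each pairing 5 consecutive items with the item that follows them.

```python
import numpy as np

def make_windows(data, window=5):
    """Turn a (num_items, num_features) array into supervised pairs.

    X[k] holds items k .. k+window-1 and y[k] is item k+window, so
    X has shape (num_items - window, window, num_features) and
    y has shape (num_items - window, num_features).
    """
    n = data.shape[0] - window
    X = np.stack([data[k:k + window] for k in range(n)])
    y = data[window:window + n]
    return X, y

data = np.arange(30, dtype=float).reshape(10, 3)   # 10 items, 3 features each
X, y = make_windows(data, window=5)
print(X.shape, y.shape)   # (5, 5, 3) (5, 3)
```

Note that y is 2-D: one target item per window, not one per timestep.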

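    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import LSTM, Dense
    from tensorflow.keras.optimizers import Adam

    n_in = 11026     # features per item
    n_hidden = 512
    model = Sequential()
    # return_sequences=True makes the LSTM emit one vector per timestep: (None, 5, 512)
    model.add(LSTM(n_hidden, activation=None, input_shape=(5, n_in), return_sequences=True))
    model.add(Dense(n_hidden, activation="linear"))
    model.add(Dense(n_in, activation="linear"))
    opt = Adam(learning_rate=0.001)
    model.compile(loss='mse', optimizer=opt)
    model.summary()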

Then I train it with:

    history = model.fit(x, y, epochs=epoch, batch_size=10, validation_data=(val_x, val_y))

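But I get this error.

I guess it is saying I should give shape [10,11026] instead of [10,5,11026].

But what I want is to give it the last 5 items and predict the next one.

Why does this error occur? (I think it works with SimpleRNN.)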

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
lstm (LSTM)                  (None, 5, 512)            23631872  
_________________________________________________________________
dense (Dense)                (None, 5, 512)            262656    
_________________________________________________________________
dense_1 (Dense)              (None, 5, 11026)          5656338   
=================================================================
Total params: 29,550,866
Trainable params: 29,550,866
Non-trainable params: 0
_________________________________________________________________
Traceback (most recent call last):
  File "manage.py", line 22, in <module>
    main()
  File "manage.py", line 18, in main
    execute_from_command_line(sys.argv)
  File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/django/core/management/__init__.py", line 401, in execute_from_command_line
    utility.execute()
  File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/django/core/management/__init__.py", line 395, in execute
    self.fetch_command(subcommand).run_from_argv(self.argv)
  File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/django/core/management/base.py", line 330, in run_from_argv
    self.execute(*args, **cmd_options)
  File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/django/core/management/base.py", line 371, in execute
    output = self.handle(*args, **options)
  File "/Users/whitebear/CodingWorks/httproot/aiwave/defapp/management/commands/learn.py", line 205, in handle
    history = model.fit(f.x, f.y, epochs=epoch, batch_size=10,validation_data=(f.val_x, f.val_y))
  File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 1098, in fit
    tmp_logs = train_function(iterator)
  File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 840, in _call
    return self._stateless_fn(*args, **kwds)
  File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2829, in __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
  File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1848, in _filtered_call
    cancellation_manager=cancellation_manager)
  File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1924, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 550, in call
    ctx=ctx)
  File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError:  Incompatible shapes: [10,5,11026] vs. [10,11026]
     [[node gradient_tape/mean_squared_error/BroadcastGradientArgs (defined at /Users/whitebear/CodingWorks/httproot/aiwave/defapp/management/commands/learn.py:205) ]] [Op:__inference_train_function_2084]

Function call stack:
train_function

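Because of return_sequences=True, the LSTM returns an output for each of the 5 timesteps, so the Dense layers are applied per timestep and your model predicts 5 items instead of 1 (output shape (None, 5, 11026), as the summary shows). The MSE loss then cannot compare that with targets of shape (10, 11026).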
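The mismatch can be reproduced without TensorFlow. Keras' mean_squared_error is the mean of the squared difference over the last axis, and the subtraction follows NumPy-style broadcasting. In this sketch, mse_like_keras is a hypothetical stand-in and 3 features replace the real 11026. It shows that a (10, 5, 3) prediction cannot be broadcast against a (10, 3) target, while a (10, 3) prediction can.

```python
import numpy as np

def mse_like_keras(y_true, y_pred):
    # Mean of squared differences over the last axis; the subtraction
    # relies on broadcasting, just like the tensor ops in the loss.
    return np.mean(np.square(y_pred - y_true), axis=-1)

batch, steps, features = 10, 5, 3   # small stand-in for 11026 features

y_true = np.zeros((batch, features))          # target: one item per sample
y_seq = np.zeros((batch, steps, features))    # model output with return_sequences=True
y_last = np.zeros((batch, features))          # model output with return_sequences=False

try:
    mse_like_keras(y_true, y_seq)
except ValueError as err:
    # Feature axes match (3 vs 3), but the next axes are 5 vs 10: no broadcast
    print("incompatible:", err)

print(mse_like_keras(y_true, y_last).shape)   # (10,)
```

Broadcasting compares shapes from the right: the feature axes agree, but then 5 (timesteps) meets 10 (batch size). If the batch size happened to equal the number of timesteps, the subtraction would broadcast silently and produce a wrong loss instead of an error.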

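Try removing return_sequences=True (it defaults to False), so the LSTM returns only its last output and the model predicts a single item of shape (None, 11026):

    model.add(LSTM(n_hidden, activation=None, input_shape=(5,11026)))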
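To illustrate what return_sequences changes, here is a minimal NumPy sketch of a standard LSTM forward pass with sigmoid/tanh gates. It is not Keras' implementation (for example, the asker's activation=None is not modelled), and lstm_forward, its weight layout and the random weights are assumptions for illustration only. With return_sequences=True it returns the hidden state at every timestep; otherwise it returns only the final one.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_forward(x, W, U, b, return_sequences=False):
    """Minimal LSTM forward pass (illustrative, not Keras' code).

    x: (batch, timesteps, features); W: (features, 4*units);
    U: (units, 4*units); b: (4*units,). Gates are packed as i, f, g, o.
    """
    batch, timesteps, _ = x.shape
    units = U.shape[0]
    h = np.zeros((batch, units))
    c = np.zeros((batch, units))
    outputs = []
    for t in range(timesteps):
        z = x[:, t, :] @ W + h @ U + b          # (batch, 4*units)
        i = sigmoid(z[:, :units])               # input gate
        f = sigmoid(z[:, units:2 * units])      # forget gate
        g = np.tanh(z[:, 2 * units:3 * units])  # candidate cell state
        o = sigmoid(z[:, 3 * units:])           # output gate
        c = f * c + i * g
        h = o * np.tanh(c)
        outputs.append(h)
    if return_sequences:
        return np.stack(outputs, axis=1)        # (batch, timesteps, units)
    return h                                    # (batch, units): last step only

rng = np.random.default_rng(0)
batch, timesteps, features, units = 10, 5, 3, 4
x = rng.standard_normal((batch, timesteps, features))
W = rng.standard_normal((features, 4 * units)) * 0.1
U = rng.standard_normal((units, 4 * units)) * 0.1
b = np.zeros(4 * units)

print(lstm_forward(x, W, U, b, return_sequences=True).shape)   # (10, 5, 4)
print(lstm_forward(x, W, U, b).shape)                          # (10, 4)
```

In the Keras model, the same change turns the LSTM output from (None, 5, 512) into (None, 512), so both Dense layers act on a single vector per sample and the final output is (None, 11026), matching y.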
