簡體   English   中英

ValueError:維度必須等於 RNN - Keras

[英]ValueError: Dimensions must be equal RNN - Keras

我正在構建一個具有不同 X 和 Y 長度的 RNN 模型。
從 NLP 我知道這應該是可能的(即 - 如果你有一個翻譯模型,你將不會有相同長度的輸入句子/單詞和輸出句子/單詞......)

不幸的是,這對我不起作用,我收到以下錯誤:

ValueError: Dimensions must be equal, but are 3 and 405 for '{{node mean_absolute_error/sub}} = Sub[T=DT_FLOAT](sequential_47/time_distributed_46/Reshape_1, IteratorGetNext:1)' with input shapes: [?,3,1], [?,405,1].

(下面的完整錯誤)

我也在網上查了一下,發現 Chollet 本人似乎參與了這個帖子,並在 2021 年將其關閉,但那里的最后一條評論(2019 年)沒有解決,並且和我有同樣的問題。

我正在使用 google colab,因此很確定這不是舊版本的問題( tf.__version__ ==> 2.8.2tf.keras.__version__ ==> 2.8.0

我的模型是:

model_gru = Sequential()
model_gru.add(GRU(75, return_sequences=True,input_shape=(train_X.shape[1], train_X.shape[2])))# , unroll=True))
model_gru.add(GRU(units=30, return_sequences=True))
model_gru.add(TimeDistributed(Dense(1)))
model_gru.compile(loss='mae', optimizer='adam')
model_gru.summary()

gru_history = model_gru.fit(train_X, train_y, epochs=30, batch_size=64, validation_data=(test_X, test_y), shuffle=False)

知道如何解決這個問題嗎?


完整錯誤:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-212-1b00467cbad2> in <module>()
----> 1 gru_history = model_gru.fit(train_X, train_y, epochs=30, batch_size=64, validation_data=(test_X, test_y), shuffle=False, callbacks=[callback])

1 frames
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in autograph_handler(*args, **kwargs)
   1145           except Exception as e:  # pylint:disable=broad-except
   1146             if hasattr(e, "ag_error_metadata"):
-> 1147               raise e.ag_error_metadata.to_exception(e)
   1148             else:
   1149               raise

ValueError: in user code:

    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1021, in train_function  *
        return step_function(self, iterator)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1010, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1000, in run_step  **
        outputs = model.train_step(data)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 860, in train_step
        loss = self.compute_loss(x, y, y_pred, sample_weight)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 919, in compute_loss
        y, y_pred, sample_weight, regularization_losses=self.losses)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/compile_utils.py", line 201, in __call__
        loss_value = loss_obj(y_t, y_p, sample_weight=sw)
    File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 141, in __call__
        losses = call_fn(y_true, y_pred)
    File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 245, in call  **
        return ag_fn(y_true, y_pred, **self._fn_kwargs)
    File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 1457, in mean_absolute_error
        return backend.mean(tf.abs(y_pred - y_true), axis=-1)

    ValueError: Dimensions must be equal, but are 3 and 405 for '{{node mean_absolute_error/sub}} = Sub[T=DT_FLOAT](sequential_47/time_distributed_46/Reshape_1, IteratorGetNext:1)' with input shapes: [?,3,1], [?,405,1].

RNN 單元需要相同長度的輸入和標簽。

RNN 為每個相應的輸入生成一個輸出。 無論您使用的是 RNN、GRU、LSTM 等,RNN 都會接收您的序列,然后從序列中的第一個元素開始,它會生成輸出和隱藏狀態。 它存儲輸出,然后傳遞隱藏狀態和序列中的下一個元素,並生成下一個輸出和新的隱藏狀態。 一旦為序列中的所有元素完成此操作,它將為您提供每個時間步的輸出和最終的隱藏狀態。 所以輸入的長度和標簽的長度必須相同。

RNN 翻譯的工作原理是對輸入進行編碼,然后將編碼后的狀態傳遞給解碼器,該解碼器按順序對其進行解碼。 即使用 RNN 的翻譯不是由單個 RNN 完成整個工作。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM