簡體   English   中英

Keras LSTM 層值錯誤:尺寸必須相等,但為 17 和 2

[英]Keras LSTM Layer ValueError: Dimensions must be equal, but are 17 and 2

我正在為一個多類任務開發一個基本的 RNN model,我在 output 尺寸方面遇到了一些問題。

這是我的輸入/輸出形狀:

input.shape = (50000, 2, 5) # (samples, features, feature_len)
output.shape = (50000, 17, 185) # (samples, features, feature_len) <-- one hot encoded 

input[0].shape =  (2, 5)
output[0].shape = (17, 185)

這是我的 model,使用 Keras 功能 API:

inp = tf.keras.Input(shape=(2, 5,))

x = tf.keras.layers.LSTM(128, input_shape=(2, 5,), return_sequences=True, activation='relu')(inp)
out = tf.keras.layers.Dense(185, activation='softmax')(x)

model = tf.keras.models.Model(inputs=inp, outputs=out)

這是我的model.summary()

Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 2, 5)]            0
_________________________________________________________________
lstm (LSTM)                  (None, 2, 128)            68608
_________________________________________________________________
dense (Dense)                (None, 2, 185)            23865
=================================================================
Total params: 92,473
Trainable params: 92,473
Non-trainable params: 0
_________________________________________________________________

然后我編譯 model 並運行fit()

model.compile(optimizer='adam',
              loss=tf.nn.softmax_cross_entropy_with_logits,
              metrics='accuracy')

model.fit(x=input, y=output, epochs=5)

我得到一個尺寸錯誤:

ValueError: Dimensions must be equal, but are 17 and 2 for '{{node Equal}} = Equal[T=DT_INT64, incompatible_shape_error=true](ArgMax, ArgMax_1)' with input shapes: [?,17], [?,2].

錯誤很明顯,model output 尺寸為2而我的 output 尺寸為17 ,雖然我理解這個問題,但我找不到解決方法,有什么想法嗎?

我認為您的 output 形狀不是“輸出 [0].shape = (17, 185)”,而是“密集 (Dense) (None, 2, 185)”。

您需要更改 output 形狀或更改圖層結構。

當您指定return_sequences=True時,LSTM output 是encoder_outputs的列表。 因此; 我建議只使用最后一項encoder_outputs作為密集層的輸入。 您可以查看此文檔鏈接的示例部分。 它可能會幫助你。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM