
Error while concatenating Keras layers: A `Concatenate` layer requires inputs with matching shapes

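I have this model, called a Hierarchical Attention Network, which was proposed for document classification:

[Figure: Hierarchical Attention Network architecture]

I use word2vec embeddings for the words of each sentence, and I want to concatenate another, sentence-level embedding at point A (see the figure).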

I'm using it for documents that contain 3 sentences. Model summary:

[Figure: model summary]

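# This snippet runs inside a class method (hence self.*); AttentionWithContext is a custom layer, not part of Keras.
from keras.layers import Input, Bidirectional, TimeDistributed, Dense, Dropout, concatenate
from keras.models import Model

word_input = Input(shape=(self.max_senten_len,), dtype='float32')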
word_sequences = self.get_embedding_layer()(word_input)
word_lstm = Bidirectional(self.hyperparameters['rnn'](self.hyperparameters['rnn_units'], return_sequences=True, kernel_regularizer=kernel_regularizer))(word_sequences)
word_dense = TimeDistributed(Dense(self.hyperparameters['dense_units'], kernel_regularizer=kernel_regularizer))(word_lstm)
word_att = AttentionWithContext()(word_dense)
wordEncoder = Model(word_input, word_att)
sent_input = Input(shape=(self.max_senten_num, self.max_senten_len), dtype='float32')
sent_encoder = TimeDistributed(wordEncoder)(sent_input)

# I added the following 2 lines. self.training_features has shape (number of training rows, 3, 512);
# 512 is the dimension of the sentence-level embedding.
USE = Input(shape=(self.training_features.shape[1], self.training_features.shape[2]), name='USE_branch')
merge = concatenate([sent_encoder, USE], axis=1)

sent_lstm = Bidirectional(self.hyperparameters['rnn'](self.hyperparameters['rnn_units'], return_sequences=True, kernel_regularizer=kernel_regularizer))(merge)
sent_dense = TimeDistributed(Dense(self.hyperparameters['dense_units'], kernel_regularizer=kernel_regularizer))(sent_lstm)
sent_att = Dropout(dropout_regularizer)(AttentionWithContext()(sent_dense))
preds = Dense(len(self.labelencoder.classes_))(sent_att)
self.model = Model(sent_input, preds)

When I compile the code above, I get the following error:

ValueError: A `Concatenate` layer requires inputs with matching shapes except for the concat axis. Got inputs shapes: [(None, 3, 128), (None, 3, 514)]

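I set the concatenation axis to 1 so that it concatenates along the number of sentences (3), but I don't understand why I still get the error.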

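This is because, when you concatenate along axis 1, every other axis must match, and the last axis does not (128 vs 514). It works if you concatenate along the last axis instead: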

merge = concatenate([sent_encoder, USE], axis=-1)

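Now the only axis that differs is the concatenation axis itself, so there is no shape conflict on the remaining axes, and the merged tensor has shape (None, 3, 642).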
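The shape rule can be reproduced without building a model. This is a minimal NumPy sketch, assuming a hypothetical batch of 2 documents and the shapes from the error message; `np.concatenate` enforces the same "all axes except the concatenation axis must match" rule as Keras' `Concatenate`, so it fails and succeeds in the same places:

```python
import numpy as np

# Hypothetical batch of 2 documents with 3 sentences each (shapes taken from the error message).
sent_encoder_out = np.zeros((2, 3, 128))  # stands in for the sentence-encoder output
use_features = np.ones((2, 3, 514))       # stands in for the USE_branch input

# axis=1 joins along the sentence axis, so the last axis (128 vs 514) must match; it does not.
try:
    np.concatenate([sent_encoder_out, use_features], axis=1)
    axis1_failed = False
except ValueError as err:
    axis1_failed = True
    print("axis=1 failed:", err)

# axis=-1 joins the feature vectors of each sentence; batch and sentence axes already match.
merged = np.concatenate([sent_encoder_out, use_features], axis=-1)
print("axis=-1:", merged.shape)  # (2, 3, 642)

# For rank-3 arrays, axis=2 and axis=-1 name the same axis, so the results are identical.
assert np.array_equal(merged, np.concatenate([sent_encoder_out, use_features], axis=2))
```

Each sentence ends up with one 642-dimensional vector: the first 128 values come from the sentence encoder and the last 514 from the extra embedding.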

The error is caused by two lines:

merge = concatenate([sent_encoder, USE], axis=1)
# should be:
merge = concatenate([sent_encoder, USE], axis=2) # or -1 as @mlRocks suggested

and the line:

self.model = Model(sent_input, preds)
# should be:
self.model = Model([sent_input, USE], preds) # to define both inputs
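Putting both fixes together, here is a small self-contained sketch using `tf.keras`. It is not the asker's model: the word-level encoder is replaced by a hypothetical `TimeDistributed(Dense(128))` over made-up per-sentence features, the custom `AttentionWithContext` layer by `GlobalAveragePooling1D`, and the feature size and number of classes are invented. What it is meant to show is the two-input wiring: concatenating on the last axis and passing both `Input` tensors to `Model`:

```python
import numpy as np
from tensorflow import keras

layers = keras.layers

MAX_SENTEN_NUM = 3   # sentences per document (from the question)
SENT_FEATURES = 20   # hypothetical per-sentence input size for the stand-in encoder
USE_DIM = 514        # last dimension of the USE_branch input in the error message
N_CLASSES = 4        # hypothetical number of labels

# Branch 1: stand-in for TimeDistributed(wordEncoder); emits 128 features per sentence.
sent_input = keras.Input(shape=(MAX_SENTEN_NUM, SENT_FEATURES), name="sent_input")
sent_encoder = layers.TimeDistributed(layers.Dense(128))(sent_input)

# Branch 2: precomputed sentence-level embeddings, one vector per sentence.
use_input = keras.Input(shape=(MAX_SENTEN_NUM, USE_DIM), name="USE_branch")

# Concatenate per-sentence vectors: (None, 3, 128) + (None, 3, 514) -> (None, 3, 642).
merge = layers.concatenate([sent_encoder, use_input], axis=-1)

sent_lstm = layers.Bidirectional(layers.LSTM(32, return_sequences=True))(merge)
# Hypothetical stand-in for the custom AttentionWithContext layer.
sent_vec = layers.GlobalAveragePooling1D()(sent_lstm)
preds = layers.Dense(N_CLASSES)(sent_vec)

# Both Input tensors are listed; preds depends on USE_branch, so omitting it breaks the graph.
model = keras.Model([sent_input, use_input], preds)

# Data is passed as a list, in the same order as the model inputs.
batch = 2
out = model.predict(
    [np.zeros((batch, MAX_SENTEN_NUM, SENT_FEATURES)),
     np.zeros((batch, MAX_SENTEN_NUM, USE_DIM))],
    verbose=0,
)
print(out.shape)  # (2, 4)
```

With the real model, training follows the same pattern, e.g. `model.fit([X_sentences, X_use], y)` with the two arrays in the order the inputs were given to `Model`.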


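Disclaimer: The technical posts on this site are licensed under CC BY-SA 4.0. If you reproduce them, please credit this site's URL or the original source. For any questions, contact: yoyou2525@163.com.

Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM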