[英]"ValueError: Input 0 is incompatible with layer gru1: expected ndim=3, found ndim=4"
[英]ValueError: Input 0 is incompatible with layer batch_normalization_1: expected ndim=3, found ndim=2
我正在尝试使用 DeepTriage 的实现,这是一种用于错误分类的深度学习方法。 该网站包括数据集、源代码和论文。 我知道这是一个非常具体的领域,但我会尽量简化。
在源代码中,他们使用以下代码部分定义了他们的方法“DBRNN-A:具有注意机制和长短期记忆单元(LSTM)的深度双向循环神经网络”:
input = Input(shape=(max_sentence_len,), dtype='int32')
sequence_embed = Embedding(vocab_size, embed_size_word2vec, input_length=max_sentence_len)(input)
forwards_1 = LSTM(1024, return_sequences=True, dropout_U=0.2)(sequence_embed)
attention_1 = SoftAttentionConcat()(forwards_1)
after_dp_forward_5 = BatchNormalization()(attention_1)
backwards_1 = LSTM(1024, return_sequences=True, dropout_U=0.2, go_backwards=True)(sequence_embed)
attention_2 = SoftAttentionConcat()(backwards_1)
after_dp_backward_5 = BatchNormalization()(attention_2)
merged = merge([after_dp_forward_5, after_dp_backward_5], mode='concat', concat_axis=-1)
after_merge = Dense(1000, activation='relu')(merged)
after_dp = Dropout(0.4)(after_merge)
output = Dense(len(train_label), activation='softmax')(after_dp)
model = Model(input=input, output=output)
model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=1e-4), metrics=['accuracy'])
SoftAttentionConcat
实现来自这里。 其余功能来自keras
。 此外,在论文中,他们共享的结构为:
在第一批标准化行中,它抛出此错误:
ValueError: Input 0 is incompatible with layer batch_normalization_1: expected ndim=3, found ndim=2
当我使用max_sentence_len=50
和max_sentence_len=200
我查看维度直到错误点,我看到这些形状:
Input -> (None, 50)
Embedding -> (None, 50, 200)
LSTM -> (None, None, 1024)
SoftAttentionConcat -> (None, 2048)
那么,有没有人看到这里的问题?
我猜问题是在 Keras 结构中使用 TensorFlow 代码或某些版本问题。
通过使用这里的问题和答案,我在 Keras 中实现了注意力机制,如下所示:
attention_1 = Dense(1, activation="tanh")(forwards_1)
attention_1 = Flatten()(attention_1) # squeeze (None,50,1)->(None,50)
attention_1 = Activation("softmax")(attention_1)
attention_1 = RepeatVector(num_rnn_unit)(attention_1)
attention_1 = Permute([2, 1])(attention_1)
attention_1 = multiply([forwards_1, attention_1])
attention_1 = Lambda(lambda xin: K.sum(xin, axis=1), output_shape=(num_rnn_unit,))(attention_1)
last_out_1 = Lambda(lambda xin: xin[:, -1, :])(forwards_1)
sent_representation_1 = concatenate([last_out_1, attention_1])
这很有效。 我用于实现的所有源代码都在GitHub 中。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.