Attention layer for a Keras seq2seq model

I have seen that Keras now comes with an Attention layer. However, I am having problems using it in my seq2seq model.

This is the working seq2seq model without attention:

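from tensorflow.keras.backend import clear_session
from tensorflow.keras.layers import (Bidirectional, Concatenate, Dense, Embedding,
                                     Input, LSTM, TimeDistributed)
from tensorflow.keras.models import Model

# max_text_len, x_voc and y_voc come from the data preprocessing (not shown)
latent_dim = 300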
embedding_dim = 200

clear_session()

# Encoder
encoder_inputs = Input(shape=(max_text_len, ))

# Embedding layer
enc_emb = Embedding(x_voc, embedding_dim,
                    trainable=True)(encoder_inputs)

# Encoder LSTM 1
encoder_lstm1 = Bidirectional(LSTM(latent_dim, return_sequences=True,
                     return_state=True, dropout=0.4,
                     recurrent_dropout=0.4))
(encoder_output1, forward_h1, forward_c1, backward_h1, backward_c1) = encoder_lstm1(enc_emb)

# Encoder LSTM 2
encoder_lstm2 = Bidirectional(LSTM(latent_dim, return_sequences=True,
                     return_state=True, dropout=0.4,
                     recurrent_dropout=0.4))
(encoder_output2, forward_h2, forward_c2, backward_h2, backward_c2) = encoder_lstm2(encoder_output1)

# Encoder LSTM 3
encoder_lstm3 = Bidirectional(LSTM(latent_dim, return_state=True,
                     return_sequences=True, dropout=0.4,
                     recurrent_dropout=0.4))
(encoder_outputs, forward_h, forward_c, backward_h, backward_c) = encoder_lstm3(encoder_output2)

state_h = Concatenate()([forward_h, backward_h])
state_c = Concatenate()([forward_c, backward_c])

# Set up the decoder, using encoder_states as the initial state
decoder_inputs = Input(shape=(None, ))

# Embedding layer
dec_emb_layer = Embedding(y_voc, embedding_dim, trainable=True)
dec_emb = dec_emb_layer(decoder_inputs)

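# Decoder LSTM (a unidirectional LSTM returns [outputs, state_h, state_c], so
# decoder_fwd_state and decoder_back_state below are really the h and c states)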
decoder_lstm = LSTM(latent_dim*2, return_sequences=True,
                    return_state=True, dropout=0.4,
                    recurrent_dropout=0.2)
(decoder_outputs, decoder_fwd_state, decoder_back_state) = \
    decoder_lstm(dec_emb, initial_state=[state_h, state_c])

# Dense layer
decoder_dense = TimeDistributed(Dense(y_voc, activation='softmax'))
decoder_outputs = decoder_dense(decoder_outputs)

# Define the model
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

model.summary()
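The decoder LSTM is built with latent_dim*2 units because its initial state is the concatenation of the forward and backward final states of the last Bidirectional encoder layer. A toy-sized NumPy sketch of that bookkeeping, assuming small hypothetical sizes in place of 128 and 300, with np.concatenate standing in for Keras's Concatenate (which joins on the last axis by default):

```python
import numpy as np

# Toy sizes standing in for batch_size=128 and latent_dim=300 (assumed for illustration).
batch, latent_dim = 2, 3

# Final hidden states of the third Bidirectional LSTM: one forward, one backward.
forward_h = np.zeros((batch, latent_dim))
backward_h = np.ones((batch, latent_dim))

# Concatenate() joins on the last axis by default, doubling the feature size.
state_h = np.concatenate([forward_h, backward_h], axis=-1)

# The decoder LSTM was built with latent_dim * 2 units, which matches this width.
decoder_units = latent_dim * 2
print(state_h.shape)  # (2, 6)
```

The same doubling applies to encoder_outputs, since Bidirectional uses merge_mode='concat' by default, so the encoder and decoder sequences end up with the same feature size. Dot-product attention between them relies on exactly that.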

I modified the model to add attention like this (it goes after # Decoder LSTM and right before # Dense layer):

attn_out, attn_states = Attention()([encoder_outputs, decoder_outputs])

decoder_concat_input = Concatenate(axis=-1)([decoder_outputs, attn_out])

# Dense layer
decoder_dense = TimeDistributed(Dense(y_voc, activation='softmax'))
decoder_outputs = decoder_dense(decoder_concat_input)

This throws TypeError: Cannot iterate over a Tensor with unknown first dimension.
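The error comes from the unpacking `attn_out, attn_states = ...`, not from the attention computation itself: Python unpacking iterates over a tensor's first axis, which for a Keras symbolic tensor is the batch axis of unknown size (None). A minimal NumPy sketch of the same unpacking on a concrete array, with toy shapes assumed for illustration, shows what Python is trying to do:

```python
import numpy as np

# A concrete stand-in for the attention output: (batch, time, features), batch known to be 2.
attn_out = np.zeros((2, 14, 600))

# Unpacking splits along axis 0 (the batch axis), not into "output" and "states".
first, second = attn_out
print(first.shape)  # (14, 600)
```

With a symbolic Keras tensor, the batch size is None, so Python cannot know how many items to produce and raises the TypeError above.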

How do I apply an attention mechanism to my seq2seq model? If the Keras Attention layer does not work, or other models are easier to use, I am happy to use them as well.

This is how I run my model:

from tensorflow.keras.callbacks import EarlyStopping

model.compile(optimizer='rmsprop', loss='sparse_categorical_crossentropy')

es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=2)

# Teacher forcing: the decoder reads y[:, :-1] and is trained to predict y[:, 1:]
history = model.fit(
    [x_tr, y_tr[:, :-1]],
    y_tr.reshape(y_tr.shape[0], y_tr.shape[1], 1)[:, 1:],
    epochs=50,
    callbacks=[es],
    batch_size=128,
    verbose=1,
    validation_data=([x_val, y_val[:, :-1]],
                     y_val.reshape(y_val.shape[0], y_val.shape[1], 1)[:, 1:]),
)

The shape of x_tr is (89674, 300) and y_tr[:, :-1] is (89674, 14). Similarly, the shapes of x_val and y_val[:, :-1] are (9964, 300) and (9964, 14), respectively.
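The slicing in the fit call is the usual teacher-forcing shift: since y_tr[:, :-1] has 14 columns, y_tr has 15, and each target token is the decoder input token one step later. A NumPy sketch on hypothetical toy data (4 sequences of 15 token ids) illustrates the shapes and the one-step offset:

```python
import numpy as np

# Hypothetical toy targets: 4 sequences, each 15 token ids long (like y_tr's 15 columns).
y = np.arange(4 * 15).reshape(4, 15)

# Decoder input: tokens 0..13 of each sequence.
decoder_input = y[:, :-1]
# Decoder target: tokens 1..14, with a trailing axis of size 1 for sparse_categorical_crossentropy.
decoder_target = y.reshape(y.shape[0], y.shape[1], 1)[:, 1:]

print(decoder_input.shape, decoder_target.shape)  # (4, 14) (4, 14, 1)
```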

You are using the Attention layer from Keras, which returns a single 3D tensor, not two tensors.

So your code must be:

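from tensorflow.keras.layers import Attention

# Attention takes [query, value]: the decoder outputs are the query and the
# encoder outputs are the value, so attn_out has one vector per decoder time step
attn_out = Attention()([decoder_outputs, encoder_outputs])
decoder_concat_input = Concatenate(axis=-1)([decoder_outputs, attn_out])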
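To see why the layer yields one tensor whose time axis follows the query, here is a NumPy sketch of dot-product (Luong-style) attention modeled on the defaults of keras.layers.Attention (no score scaling, key equal to value). The function name dot_product_attention and the toy sizes are hypothetical, and this is an illustration of the math rather than the layer's actual implementation:

```python
import numpy as np

def dot_product_attention(query, value):
    """Hypothetical helper: dot-product attention with key = value and no scaling."""
    scores = query @ value.transpose(0, 2, 1)              # (batch, Tq, Tv)
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)         # softmax over value (encoder) steps
    return weights @ value, weights                        # (batch, Tq, dim), (batch, Tq, Tv)

rng = np.random.default_rng(0)
batch, dim = 2, 6  # toy sizes; in the question dim = 2 * latent_dim = 600
decoder_outputs = rng.normal(size=(batch, 14, dim))  # query: 14 decoder steps
encoder_outputs = rng.normal(size=(batch, 30, dim))  # value: 30 encoder steps

attn_out, weights = dot_product_attention(decoder_outputs, encoder_outputs)
decoder_concat_input = np.concatenate([decoder_outputs, attn_out], axis=-1)

# Swapping the arguments gives one vector per *encoder* step instead.
swapped, _ = dot_product_attention(encoder_outputs, decoder_outputs)
print(attn_out.shape, decoder_concat_input.shape, swapped.shape)  # (2, 14, 6) (2, 14, 12) (2, 30, 6)
```

The swapped result has 30 time steps and cannot be concatenated with the 14-step decoder outputs, which is why the decoder outputs go first. The helper returns the weights separately only for inspection; the Keras layer returns just the attended tensor by default, and recent TensorFlow versions also accept return_attention_scores=True when calling the layer if you need the weights.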
