
LSTM with self attention for multi class text classification

I am following the self-attention in Keras example from the following link: How to add attention layer to a Bi-LSTM

I want to apply a Bi-LSTM for multi-class text classification with 3 classes.

I tried to apply the attention in my code, but I got the error below. How can I solve this problem? Can anyone help me, please?

Incompatible shapes: [100,3] vs. [64,3]
     [[Node: training_1/Adam/gradients/loss_11/dense_14_loss/mul_grad/BroadcastGradientArgs = BroadcastGradientArgs[T=DT_INT32, _class=["loc:@training_1/Adam/gradients/loss_11/dense_14_loss/mul_grad/Reshape_1"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](training_1/Adam/gradients/loss_11/dense_14_loss/mul_grad/Shape, training_1/Adam/gradients/loss_11/dense_14_loss/mul_grad/Shape_1)]]




from keras.layers import Layer
from keras import backend as K


class attention(Layer):

    def __init__(self, return_sequences=False):
        self.return_sequences = return_sequences
        super(attention, self).__init__()

    def build(self, input_shape):
        # W scores each feature; b adds a per-timestep bias
        self.W = self.add_weight(name="att_weight", shape=(input_shape[-1], 1),
                                 initializer="normal")
        self.b = self.add_weight(name="att_bias", shape=(input_shape[1], 1),
                                 initializer="zeros")

        super(attention, self).build(input_shape)

    def call(self, x):
        # e: unnormalised score per timestep; a: softmax over the time axis
        e = K.tanh(K.dot(x, self.W) + self.b)
        a = K.softmax(e, axis=1)
        output = x * a

        if self.return_sequences:
            return output              # (batch, timesteps, features)

        return K.sum(output, axis=1)   # (batch, features): weighted sum over time
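
For reference, a quick shape check (a sketch only; the 20-timestep, 64-feature dimensions below are illustrative, not taken from the question) shows what this layer expects and returns:

from keras.layers import Input
from keras import backend as K

inp = Input(shape=(20, 64))                   # 3D input: (timesteps, features)
seq = attention(return_sequences=True)(inp)   # keeps the time axis
vec = attention(return_sequences=False)(inp)  # weighted sum over the time axis

print(K.int_shape(seq))   # (None, 20, 64)
print(K.int_shape(vec))   # (None, 64)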




model = Sequential()
model.add(Embedding(17666, 100, input_length=409))
model.add(Bidirectional(LSTM(32, return_sequences=False)))
model.add(attention(return_sequences=True)) # receive 3D and output 2D
model.add(Dropout(0.3))
model.add(Dense(3, activation='softmax'))


model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.summary()

from keras.callbacks import EarlyStopping
es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=3)
history777 = model.fit(x_train, y_train,
                       batch_size=100,
                       epochs=30,
                       validation_data=(x_val, y_val),
                       callbacks=[es])
The model summary:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_14 (Embedding)     (None, 409, 100)          1766600   
_________________________________________________________________
bidirectional_14 (Bidirectio (None, 64)                34048     
_________________________________________________________________
attention_14 (attention)     (None, 64)                128       
_________________________________________________________________
dropout_6 (Dropout)          (None, 64)                0         
_________________________________________________________________
dense_14 (Dense)             (None, 3)                 195       
=================================================================
Total params: 1,800,971
Trainable params: 1,800,971
Non-trainable params: 0
_________________________________________________________________




Pay attention to how you set the return_sequences param in the LSTM and attention layers.

Your output is 2D, so the last return_sequences must be set to False while the others must be set to True. In your model the LSTM already has return_sequences=False, so the attention layer receives a 2D tensor (batch, features) instead of the 3D tensor (batch, timesteps, features) it expects; its bias is then built over the feature axis (64, the Bi-LSTM output width) rather than the time axis, which leads to the [100,3] vs. [64,3] mismatch with your batch size of 100.

Your model must be:

model = Sequential()
model.add(Embedding(max_words, emb_dim, input_length=max_len))
model.add(Bidirectional(LSTM(32, return_sequences=True))) # return_sequences=True
model.add(attention(return_sequences=False)) # return_sequences=False
model.add(Dropout(0.3))
model.add(Dense(3, activation='softmax'))
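
Compiling and training can then stay exactly as in your code; a sketch reusing your optimizer, loss, batch size, and early-stopping settings:

from keras.callbacks import EarlyStopping

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=3)
history = model.fit(x_train, y_train,
                    batch_size=100,
                    epochs=30,
                    validation_data=(x_val, y_val),
                    callbacks=[es])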

Here is the full example: https://colab.research.google.com/drive/13l5eAHS5uTUsdqyQNm1Dr4JEXg7Fl2Bo?usp=sharing
