繁体   English   中英

在 Keras 中为多标签文本分类神经网络创建一个带有 Attention 的 LSTM 层

[英]Create an LSTM layer with Attention in Keras for multi-label text classification neural network

问候亲爱的社区成员。 我正在创建一个神经网络来预测多标签 y。 具体来说,神经网络接受 5 个输入(演员列表、plot 摘要、电影特征、电影评论、标题)并尝试预测电影类型的序列。 在神经网络中,我使用嵌入层和全局最大池化层。

然而,我最近发现了 Recurrent Layers with Attention,这是近来机器学习翻译中一个非常有趣的话题。 所以,我想知道我是否可以使用其中一个层,但只能使用 Plot 摘要输入。 请注意,我不进行机器学习翻译,而是进行文本分类。

我当前的神经网络 state

def create_fit_keras_model(hparams,
                           version_data_control,
                           optimizer_name,
                           validation_method,
                           callbacks,
                           optimizer_version = None):

    sentenceLength_actors = X_train_seq_actors.shape[1]
    vocab_size_frequent_words_actors = len(actors_tokenizer.word_index)

    sentenceLength_plot = X_train_seq_plot.shape[1]
    vocab_size_frequent_words_plot = len(plot_tokenizer.word_index)

    sentenceLength_features = X_train_seq_features.shape[1]
    vocab_size_frequent_words_features = len(features_tokenizer.word_index)

    sentenceLength_reviews = X_train_seq_reviews.shape[1]
    vocab_size_frequent_words_reviews = len(reviews_tokenizer.word_index)

    sentenceLength_title = X_train_seq_title.shape[1]
    vocab_size_frequent_words_title = len(title_tokenizer.word_index)

    model = keras.Sequential(name='{0}_{1}dim_{2}batchsize_{3}lr_{4}decaymultiplier_{5}'.format(sequential_model_name, 
                                                                                                str(hparams[HP_EMBEDDING_DIM]), 
                                                                                                str(hparams[HP_HIDDEN_UNITS]),
                                                                                                str(hparams[HP_LEARNING_RATE]), 
                                                                                                str(hparams[HP_DECAY_STEPS_MULTIPLIER]),
                                                                                                version_data_control))
    actors = keras.Input(shape=(sentenceLength_actors,), name='actors_input')
    plot = keras.Input(shape=(sentenceLength_plot,), batch_size=hparams[HP_HIDDEN_UNITS], name='plot_input')
    features = keras.Input(shape=(sentenceLength_features,), name='features_input')
    reviews = keras.Input(shape=(sentenceLength_reviews,), name='reviews_input')
    title = keras.Input(shape=(sentenceLength_title,), name='title_input')

    emb1 = layers.Embedding(input_dim = vocab_size_frequent_words_actors + 2,
                            output_dim = 16, #hparams[HP_EMBEDDING_DIM], hyperparametered or fixed sized.
                            embeddings_initializer = 'uniform',
                            mask_zero = True,
                            input_length = sentenceLength_actors,
                            name="actors_embedding_layer")(actors)
    
    # encoded_layer1 = layers.GlobalAveragePooling1D(name="globalaveragepooling_actors_layer")(emb1)
    encoded_layer1 = layers.GlobalMaxPooling1D(name="globalmaxpooling_actors_layer")(emb1)
    
    emb2 = layers.Embedding(input_dim = vocab_size_frequent_words_plot + 2,
                            output_dim = hparams[HP_EMBEDDING_DIM],
                            embeddings_initializer = 'uniform',
                            mask_zero = True,
                            input_length = sentenceLength_plot,
                            name="plot_embedding_layer")(plot)
    # (Option 1)
    # encoded_layer2 = layers.GlobalMaxPooling1D(name="globalmaxpooling_plot_summary_Layer")(emb2)
 
    # (Option 2)
    emb2 = layers.Bidirectional(layers.LSTM(hparams[HP_EMBEDDING_DIM], return_sequences=True))(emb2)
    avg_pool = layers.GlobalAveragePooling1D()(emb2)
    max_pool = layers.GlobalMaxPooling1D()(emb2)
    conc = layers.concatenate([avg_pool, max_pool])

    # (Option 3)
    # emb2 = layers.Bidirectional(layers.LSTM(hparams[HP_EMBEDDING_DIM], return_sequences=True))(emb2)
    # emb2 = layers.Bidirectional(layers.LSTM(hparams[HP_EMBEDDING_DIM], return_sequences=True))(emb2)
    # emb2 = AttentionWithContext()(emb2)

    emb3 = layers.Embedding(input_dim = vocab_size_frequent_words_features + 2,
                            output_dim = hparams[HP_EMBEDDING_DIM],
                            embeddings_initializer = 'uniform',
                            mask_zero = True,
                            input_length = sentenceLength_features,
                            name="features_embedding_layer")(features)
    
    # encoded_layer3 = layers.GlobalAveragePooling1D(name="globalaveragepooling_movie_features_layer")(emb3)
    encoded_layer3 = layers.GlobalMaxPooling1D(name="globalmaxpooling_movie_features_layer")(emb3)
    
    emb4 = layers.Embedding(input_dim = vocab_size_frequent_words_reviews + 2,
                            output_dim = hparams[HP_EMBEDDING_DIM],
                            embeddings_initializer = 'uniform',
                            mask_zero = True,
                            input_length = sentenceLength_reviews,
                            name="reviews_embedding_layer")(reviews)
    
    # encoded_layer4 = layers.GlobalAveragePooling1D(name="globalaveragepooling_user_reviews_layer")(emb4)
    encoded_layer4 = layers.GlobalMaxPooling1D(name="globalmaxpooling_user_reviews_layer")(emb4)

    emb5 = layers.Embedding(input_dim = vocab_size_frequent_words_title + 2,
                            output_dim = hparams[HP_EMBEDDING_DIM],
                            embeddings_initializer = 'uniform',
                            mask_zero = True,
                            input_length = sentenceLength_title,
                            name="title_embedding_layer")(title)
    
    # encoded_layer5 = layers.GlobalAveragePooling1D(name="globalaveragepooling_movie_title_layer")(emb5)
    encoded_layer5 = layers.GlobalMaxPooling1D(name="globalmaxpooling_movie_title_layer")(emb5)

    merged = layers.concatenate([encoded_layer1, conc, encoded_layer3, encoded_layer4, encoded_layer5], axis=-1) #(Option 2)
    # merged = layers.concatenate([encoded_layer1, emb2, encoded_layer3, encoded_layer4, encoded_layer5], axis=-1) #(Option 3)

    dense_layer_1 = layers.Dense(hparams[HP_HIDDEN_UNITS],
                                 kernel_regularizer=regularizers.l2(neural_network_parameters['l2_regularization']),
                                 activation=neural_network_parameters['dense_activation'],
                                 name="1st_dense_hidden_layer_concatenated_inputs")(merged)
    
    layers.Dropout(neural_network_parameters['dropout_rate'])(dense_layer_1)
    
    output_layer = layers.Dense(neural_network_parameters['number_target_variables'],
                                activation=neural_network_parameters['output_activation'],
                                name='output_layer')(dense_layer_1)

    model = keras.Model(inputs=[actors, plot, features, reviews, title], outputs=output_layer, name='{0}_{1}dim_{2}batchsize_{3}lr_{4}decaymultiplier_{5}'.format(sequential_model_name, 
                                                                                                                                                                  str(hparams[HP_EMBEDDING_DIM]), 
                                                                                                                                                                  str(hparams[HP_HIDDEN_UNITS]),
                                                                                                                                                                  str(hparams[HP_LEARNING_RATE]), 
                                                                                                                                                                  str(hparams[HP_DECAY_STEPS_MULTIPLIER]),
                                                                                                                                                                  version_data_control))
    print(model.summary())
    
#     pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.0,
#                                                             final_sparsity=0.4,
#                                                             begin_step=600,
#                                                             end_step=1000)
    
#     model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, pruning_schedule=pruning_schedule)
    
    if optimizer_name=="adam" and optimizer_version is None:
        
        optimizer = optimizer_adam_v2(hparams)
        
    elif optimizer_name=="sgd" and optimizer_version is None:
        
        optimizer = optimizer_sgd_v1(hparams, "no decay")
        
    elif optimizer_name=="rmsprop" and optimizer_version is None:
        
        optimizer = optimizer_rmsprop_v1(hparams)

    print("here: {0}".format(optimizer.lr))

    lr_metric = [get_lr_metric(optimizer)]
    
    if type(get_lr_metric(optimizer)) in (float, int):

        print("Learning Rate's type is Float or Integer")
        model.compile(optimizer=optimizer,
                      loss=neural_network_parameters['model_loss'],
                      metrics=neural_network_parameters['model_metric'] + lr_metric, )
    else:
        print("Learning Rate's type is not Float or Integer, but rather {0}".format(type(lr_metric)))
        model.compile(optimizer=optimizer,
                      loss=neural_network_parameters['model_loss'],
                      metrics=neural_network_parameters['model_metric'], ) #+ lr_metric

您将在上面的结构中看到我有 5 个输入层、5 个嵌入层,然后我仅在 Plot 汇总输入中在 LSTM 上应用双向层。

但是,使用 Plot 摘要上的当前双向方法,我收到以下错误。 我的问题是如何在文本分类中利用注意力而不解决下面的错误。 所以,不要评论这个错误的解决方案。

在此处输入图像描述

我的问题是关于如何为 plot 摘要(输入 2)提出如何创建循环层的建议。 此外,不要犹豫,在评论中写下任何可能有助于我在 Keras 中实现这一目标的文章。

如果需要有关神经网络结构的任何其他信息,我将随时为您服务。

如果你觉得上面的神经网络很复杂,我可以做一个简单的版本。 然而,上面是我原来的神经网络,所以我希望任何建议都基于那个 nn。


编辑:2020 年 12 月 14 日

在此处找到包含我要执行的代码的 colab 笔记本。 该代码包含两个答案,一个是在评论中提出的(来自一个已经回答的问题,另一个是对我的问题的正式回答。

@MarcoCerliani 提出的第一种方法有效。 虽然,我也想要第二种工作方式。 @Allohvk 的方法(这两种方法都在附加的 colab 的运行时单元 [21] 中实现) 后者目前不起作用。 我得到的最新错误是:

ValueError: 图层 globalmaxpooling_plot_summary_Layer 的输入 0 与图层不兼容:预期 ndim=3,发现 ndim=2。 收到的完整形状:[无,100]

我通过从我的神经网络结构中删除globalmaxpooling_plot_summary_Layer解决了我编辑的最新错误。

让我总结一下意图。 您想增加对代码的关注。 您的任务是序列分类任务,而不是 seq-seq 翻译器。 你并不真正关心它的完成方式,所以你可以不调试上面的错误,但只需要一段工作代码。 我们这里的主要输入是由您想要增加注意力的“n”个单词组成的电影评论。

假设您嵌入评论并将其传递给 LSTM 层。 现在您想要“关注” LSTM 层的所有隐藏状态,然后生成分类(而不是仅使用编码器的最后一个隐藏 state)。 所以需要插入一个注意力层。 准系统实现如下所示:

    def __init__(self):    
        ##Nothing special to be done here
        super(peel_the_layer, self).__init__()
        
    def build(self, input_shape):
        ##Define the shape of the weights and bias in this layer
        ##This is a 1 unit layer. 
        units=1
        ##last index of the input_shape is the number of dimensions of the prev
        ##RNN layer. last but 1 index is the num of timesteps
        self.w=self.add_weight(name="att_weights", shape=(input_shape[-1], units), initializer="normal") #name property is useful for avoiding RuntimeError: Unable to create link.
        self.b=self.add_weight(name="att_bias", shape=(input_shape[-2], units), initializer="zeros")
        super(peel_the_layer,self).build(input_shape)
        
    def call(self, x):
        ##x is the input tensor..each word that needs to be attended to
        ##Below is the main processing done during training
        ##K is the Keras Backend import
        e = K.tanh(K.dot(x,self.w)+self.b)
        a = K.softmax(e, axis=1)
        output = x*a
        
        ##return the ouputs. 'a' is the set of attention weights
        ##the second variable is the 'attention adjusted o/p state' or context
        return a, K.sum(output, axis=1)

现在在 LSTM 之后和 Dense output 层之前调用上述注意力层。

        a, context = peel_the_layer()(lstm_out)
        ##context is the o/p which be the input to your classification layer
        ##a is the set of attention weights and you may want to route them to a display

您可以在此基础上进行构建,因为您似乎想使用其他功能来让电影评论得出最终的观点。 注意主要适用于评论......如果句子很长,就会看到好处。

更多具体细节请参考https://towardsdatascience.com/create-your-own-custom-attention-layer-understand-all-flavours-2201b5e8be9e

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM