簡體   English   中英

在 Keras 中為多標簽文本分類神經網絡創建一個帶有 Attention 的 LSTM 層

[英]Create an LSTM layer with Attention in Keras for multi-label text classification neural network

問候親愛的社區成員。 我正在創建一個神經網絡來預測多標簽 y。 具體來說,神經網絡接受 5 個輸入(演員列表、plot 摘要、電影特征、電影評論、標題)並嘗試預測電影類型的序列。 在神經網絡中,我使用嵌入層和全局最大池化層。

然而,我最近發現了 Recurrent Layers with Attention,這是近來機器學習翻譯中一個非常有趣的話題。 所以,我想知道我是否可以使用其中一個層,但只能使用 Plot 摘要輸入。 請注意,我不進行機器學習翻譯,而是進行文本分類。

我當前的神經網絡 state

def create_fit_keras_model(hparams,
                           version_data_control,
                           optimizer_name,
                           validation_method,
                           callbacks,
                           optimizer_version = None):

    sentenceLength_actors = X_train_seq_actors.shape[1]
    vocab_size_frequent_words_actors = len(actors_tokenizer.word_index)

    sentenceLength_plot = X_train_seq_plot.shape[1]
    vocab_size_frequent_words_plot = len(plot_tokenizer.word_index)

    sentenceLength_features = X_train_seq_features.shape[1]
    vocab_size_frequent_words_features = len(features_tokenizer.word_index)

    sentenceLength_reviews = X_train_seq_reviews.shape[1]
    vocab_size_frequent_words_reviews = len(reviews_tokenizer.word_index)

    sentenceLength_title = X_train_seq_title.shape[1]
    vocab_size_frequent_words_title = len(title_tokenizer.word_index)

    model = keras.Sequential(name='{0}_{1}dim_{2}batchsize_{3}lr_{4}decaymultiplier_{5}'.format(sequential_model_name, 
                                                                                                str(hparams[HP_EMBEDDING_DIM]), 
                                                                                                str(hparams[HP_HIDDEN_UNITS]),
                                                                                                str(hparams[HP_LEARNING_RATE]), 
                                                                                                str(hparams[HP_DECAY_STEPS_MULTIPLIER]),
                                                                                                version_data_control))
    actors = keras.Input(shape=(sentenceLength_actors,), name='actors_input')
    plot = keras.Input(shape=(sentenceLength_plot,), batch_size=hparams[HP_HIDDEN_UNITS], name='plot_input')
    features = keras.Input(shape=(sentenceLength_features,), name='features_input')
    reviews = keras.Input(shape=(sentenceLength_reviews,), name='reviews_input')
    title = keras.Input(shape=(sentenceLength_title,), name='title_input')

    emb1 = layers.Embedding(input_dim = vocab_size_frequent_words_actors + 2,
                            output_dim = 16, #hparams[HP_EMBEDDING_DIM], hyperparametered or fixed sized.
                            embeddings_initializer = 'uniform',
                            mask_zero = True,
                            input_length = sentenceLength_actors,
                            name="actors_embedding_layer")(actors)
    
    # encoded_layer1 = layers.GlobalAveragePooling1D(name="globalaveragepooling_actors_layer")(emb1)
    encoded_layer1 = layers.GlobalMaxPooling1D(name="globalmaxpooling_actors_layer")(emb1)
    
    emb2 = layers.Embedding(input_dim = vocab_size_frequent_words_plot + 2,
                            output_dim = hparams[HP_EMBEDDING_DIM],
                            embeddings_initializer = 'uniform',
                            mask_zero = True,
                            input_length = sentenceLength_plot,
                            name="plot_embedding_layer")(plot)
    # (Option 1)
    # encoded_layer2 = layers.GlobalMaxPooling1D(name="globalmaxpooling_plot_summary_Layer")(emb2)
 
    # (Option 2)
    emb2 = layers.Bidirectional(layers.LSTM(hparams[HP_EMBEDDING_DIM], return_sequences=True))(emb2)
    avg_pool = layers.GlobalAveragePooling1D()(emb2)
    max_pool = layers.GlobalMaxPooling1D()(emb2)
    conc = layers.concatenate([avg_pool, max_pool])

    # (Option 3)
    # emb2 = layers.Bidirectional(layers.LSTM(hparams[HP_EMBEDDING_DIM], return_sequences=True))(emb2)
    # emb2 = layers.Bidirectional(layers.LSTM(hparams[HP_EMBEDDING_DIM], return_sequences=True))(emb2)
    # emb2 = AttentionWithContext()(emb2)

    emb3 = layers.Embedding(input_dim = vocab_size_frequent_words_features + 2,
                            output_dim = hparams[HP_EMBEDDING_DIM],
                            embeddings_initializer = 'uniform',
                            mask_zero = True,
                            input_length = sentenceLength_features,
                            name="features_embedding_layer")(features)
    
    # encoded_layer3 = layers.GlobalAveragePooling1D(name="globalaveragepooling_movie_features_layer")(emb3)
    encoded_layer3 = layers.GlobalMaxPooling1D(name="globalmaxpooling_movie_features_layer")(emb3)
    
    emb4 = layers.Embedding(input_dim = vocab_size_frequent_words_reviews + 2,
                            output_dim = hparams[HP_EMBEDDING_DIM],
                            embeddings_initializer = 'uniform',
                            mask_zero = True,
                            input_length = sentenceLength_reviews,
                            name="reviews_embedding_layer")(reviews)
    
    # encoded_layer4 = layers.GlobalAveragePooling1D(name="globalaveragepooling_user_reviews_layer")(emb4)
    encoded_layer4 = layers.GlobalMaxPooling1D(name="globalmaxpooling_user_reviews_layer")(emb4)

    emb5 = layers.Embedding(input_dim = vocab_size_frequent_words_title + 2,
                            output_dim = hparams[HP_EMBEDDING_DIM],
                            embeddings_initializer = 'uniform',
                            mask_zero = True,
                            input_length = sentenceLength_title,
                            name="title_embedding_layer")(title)
    
    # encoded_layer5 = layers.GlobalAveragePooling1D(name="globalaveragepooling_movie_title_layer")(emb5)
    encoded_layer5 = layers.GlobalMaxPooling1D(name="globalmaxpooling_movie_title_layer")(emb5)

    merged = layers.concatenate([encoded_layer1, conc, encoded_layer3, encoded_layer4, encoded_layer5], axis=-1) #(Option 2)
    # merged = layers.concatenate([encoded_layer1, emb2, encoded_layer3, encoded_layer4, encoded_layer5], axis=-1) #(Option 3)

    dense_layer_1 = layers.Dense(hparams[HP_HIDDEN_UNITS],
                                 kernel_regularizer=regularizers.l2(neural_network_parameters['l2_regularization']),
                                 activation=neural_network_parameters['dense_activation'],
                                 name="1st_dense_hidden_layer_concatenated_inputs")(merged)
    
    layers.Dropout(neural_network_parameters['dropout_rate'])(dense_layer_1)
    
    output_layer = layers.Dense(neural_network_parameters['number_target_variables'],
                                activation=neural_network_parameters['output_activation'],
                                name='output_layer')(dense_layer_1)

    model = keras.Model(inputs=[actors, plot, features, reviews, title], outputs=output_layer, name='{0}_{1}dim_{2}batchsize_{3}lr_{4}decaymultiplier_{5}'.format(sequential_model_name, 
                                                                                                                                                                  str(hparams[HP_EMBEDDING_DIM]), 
                                                                                                                                                                  str(hparams[HP_HIDDEN_UNITS]),
                                                                                                                                                                  str(hparams[HP_LEARNING_RATE]), 
                                                                                                                                                                  str(hparams[HP_DECAY_STEPS_MULTIPLIER]),
                                                                                                                                                                  version_data_control))
    print(model.summary())
    
#     pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.0,
#                                                             final_sparsity=0.4,
#                                                             begin_step=600,
#                                                             end_step=1000)
    
#     model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, pruning_schedule=pruning_schedule)
    
    if optimizer_name=="adam" and optimizer_version is None:
        
        optimizer = optimizer_adam_v2(hparams)
        
    elif optimizer_name=="sgd" and optimizer_version is None:
        
        optimizer = optimizer_sgd_v1(hparams, "no decay")
        
    elif optimizer_name=="rmsprop" and optimizer_version is None:
        
        optimizer = optimizer_rmsprop_v1(hparams)

    print("here: {0}".format(optimizer.lr))

    lr_metric = [get_lr_metric(optimizer)]
    
    if type(get_lr_metric(optimizer)) in (float, int):

        print("Learning Rate's type is Float or Integer")
        model.compile(optimizer=optimizer,
                      loss=neural_network_parameters['model_loss'],
                      metrics=neural_network_parameters['model_metric'] + lr_metric, )
    else:
        print("Learning Rate's type is not Float or Integer, but rather {0}".format(type(lr_metric)))
        model.compile(optimizer=optimizer,
                      loss=neural_network_parameters['model_loss'],
                      metrics=neural_network_parameters['model_metric'], ) #+ lr_metric

您將在上面的結構中看到我有 5 個輸入層、5 個嵌入層,然后我僅在 Plot 匯總輸入中在 LSTM 上應用雙向層。

但是,使用 Plot 摘要上的當前雙向方法,我收到以下錯誤。 我的問題是如何在文本分類中利用注意力而不解決下面的錯誤。 所以,不要評論這個錯誤的解決方案。

在此處輸入圖像描述

我的問題是關於如何為 plot 摘要(輸入 2)提出如何創建循環層的建議。 此外,不要猶豫,在評論中寫下任何可能有助於我在 Keras 中實現這一目標的文章。

如果需要有關神經網絡結構的任何其他信息,我將隨時為您服務。

如果你覺得上面的神經網絡很復雜,我可以做一個簡單的版本。 然而,上面是我原來的神經網絡,所以我希望任何建議都基於那個 nn。


編輯:2020 年 12 月 14 日

在此處找到包含我要執行的代碼的 colab 筆記本。 該代碼包含兩個答案,一個是在評論中提出的(來自一個已經回答的問題,另一個是對我的問題的正式回答。

@MarcoCerliani 提出的第一種方法有效。 雖然,我也想要第二種工作方式。 @Allohvk 的方法(這兩種方法都在附加的 colab 的運行時單元 [21] 中實現) 后者目前不起作用。 我得到的最新錯誤是:

ValueError: 圖層 globalmaxpooling_plot_summary_Layer 的輸入 0 與圖層不兼容:預期 ndim=3,發現 ndim=2。 收到的完整形狀:[無,100]

我通過從我的神經網絡結構中刪除globalmaxpooling_plot_summary_Layer解決了我編輯的最新錯誤。

讓我總結一下意圖。 您想增加對代碼的關注。 您的任務是序列分類任務,而不是 seq-seq 翻譯器。 你並不真正關心它的完成方式,所以你可以不調試上面的錯誤,但只需要一段工作代碼。 我們這里的主要輸入是由您想要增加注意力的“n”個單詞組成的電影評論。

假設您嵌入評論並將其傳遞給 LSTM 層。 現在您想要“關注” LSTM 層的所有隱藏狀態,然后生成分類(而不是僅使用編碼器的最后一個隱藏 state)。 所以需要插入一個注意力層。 准系統實現如下所示:

    def __init__(self):    
        ##Nothing special to be done here
        super(peel_the_layer, self).__init__()
        
    def build(self, input_shape):
        ##Define the shape of the weights and bias in this layer
        ##This is a 1 unit layer. 
        units=1
        ##last index of the input_shape is the number of dimensions of the prev
        ##RNN layer. last but 1 index is the num of timesteps
        self.w=self.add_weight(name="att_weights", shape=(input_shape[-1], units), initializer="normal") #name property is useful for avoiding RuntimeError: Unable to create link.
        self.b=self.add_weight(name="att_bias", shape=(input_shape[-2], units), initializer="zeros")
        super(peel_the_layer,self).build(input_shape)
        
    def call(self, x):
        ##x is the input tensor..each word that needs to be attended to
        ##Below is the main processing done during training
        ##K is the Keras Backend import
        e = K.tanh(K.dot(x,self.w)+self.b)
        a = K.softmax(e, axis=1)
        output = x*a
        
        ##return the ouputs. 'a' is the set of attention weights
        ##the second variable is the 'attention adjusted o/p state' or context
        return a, K.sum(output, axis=1)

現在在 LSTM 之后和 Dense output 層之前調用上述注意力層。

        a, context = peel_the_layer()(lstm_out)
        ##context is the o/p which be the input to your classification layer
        ##a is the set of attention weights and you may want to route them to a display

您可以在此基礎上進行構建,因為您似乎想使用其他功能來讓電影評論得出最終的觀點。 注意主要適用於評論......如果句子很長,就會看到好處。

更多具體細節請參考https://towardsdatascience.com/create-your-own-custom-attention-layer-understand-all-flavours-2201b5e8be9e

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM