简体   繁体   English

在Python中使用Keras将27个字段的输入连接到LSTM层的输出

[英]Concatenate an input of 27 fields to the output of the LSTM layer using Keras in Python

I have an existing LSTM model that looks as follows: 我有一个现有的LSTM模型,如下所示:

model_glove1 = Sequential()
model_glove1.add(Embedding(vocabulary_size, 25, input_length=50, weights=[embedding_matrix25],trainable=False))
model_glove1.add(LSTM(32))
model_glove1.add(Dense(128, activation='relu'))
model_glove1.add(Dense(64, activation='relu'))
model_glove1.add(Dense(1, activation='softmax'))
model_glove1.compile(loss='binary_crossentropy',optimizer='adam',metrics['accuracy',auc_roc])
model_glove1.fit(data, np.array(train_y), batch_size=32,
epochs=4,
verbose=1,
validation_split=0.1,
shuffle=True)

I want to add an additional auxiliary input layer which is present in a dataframe of 27 columns . 我想添加一个附加的辅助输入层,该输入层存在于27列的数据框中。 I want that layer to be concatenated with the output of the LSTM layer. 我希望该层与LSTM层的输出连接在一起。 Is it possible ? 可能吗 ? If so how can I achieve it? 如果可以,我该如何实现?

Before using the code, please check the secondary input has the same dimension like output of LSTM layer. 在使用代码之前,请检查辅助输入的尺寸是否与LSTM层的输出相同。

Moreover, in model1_glove.fit() function, you need to provide two inputs 此外,在model1_glove.fit()函数中,您需要提供两个输入

def NNStructure():
    initial_input= Embedding(vocabulary_size, 25, input_length=50, weights= 
    [embedding_matrix25],trainable=False) 
    lstm = LSTM(32)(initial_input)   
    secondary_input = Input(shape=(Number_of_row,27))    
    merge = concatenate([lstm, secondary_input])
    first_dense = Dense(128, activation='relu')(merge)
    second_dense=Dense(64, activation='relu')(first_dense)
    output=Dense(1, activation='softmax')(second_dense)

    model_glove1 = Model(inputs=[initial_input, secondary_input], outputs=output)
    return model_glove1

model_glove1=NNStructure()
model_glove1.compile(loss='binary_crossentropy',optimizer='adam',metrics['accuracy',auc_roc])
model_glove1.fit(x=[data1,data2], y=np.array(train_y), batch_size=32,
epochs=4,
verbose=1,
validation_split=0.1,
shuffle=True)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM