[英]Passing multiple inputs to CNN model
我有代表域名中每個字符的整數向量和另一個表示時間軸信息的整數向量。 我需要將這兩個向量作為CNN模型的輸入,將域名分類為好或垃圾郵件。
例如,
矢量代表域名 - > 1 x 75矢量。 向量中的每個元素表示域名中的每個字符。 如果有1000個域名,那么它將是1000 x 75的形狀矩陣
矢量表示時間線信息 - > 1 x 1440矢量。 每個元素表示每分鍾從特定域發送的郵件數。 如果有1000個域名,那么它將是一個形狀為1000 x 1440的矩陣
如何將這兩個向量輸入到單個CNN模型?
我當前的模型只給出了域名作為輸入,
def build_model(max_features, maxlen):
"""Build CNN model"""
model = Sequential()
model.add(Embedding(max_features, 8, input_length=maxlen))
model.add(Convolution1D(6, 4, border_mode='same'))
model.add(Convolution1D(4, 4, border_mode='same'))
model.add(Convolution1D(2, 4, border_mode='same'))
model.add(Flatten())
#model.add(Dropout(0.2))
#model.add(Dense(2,activation='sigmoid'))
#model.add(Dense(180,activation='sigmoid'))
#model.add(Dropout(0.2))
model.add(Dense(2,activation='softmax'))
sgd = optimizers.SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['categorical_accuracy', 'f1score', 'precision', 'recall'])
謝謝!
在卷積中,您需要“長度”維度和“渠道”維度。
(在2D中,它們將是“寬度”,“高度”和“通道”)。
現在,我想不出任何方法可以將75個字符與1440分鍾聯系起來。 (也許你可以,如果你能說明如何,也許我們可以更好地工作)
這是我假設的:
所以,我們有兩個輸入:
from keras.layers import *
input1 = Input((75,))
input2 = Input((1440,))
只有域名應該通過嵌入層:
name = Embedding(max_features, 8, input_length=maxlen)(input1)
現在,重新整形以適應卷積輸入(None,length,channels)
。
# the embedding output is already (Batch, 75, 8) -- See: https://keras.io/layers/embeddings/
mails = Reshape((1440,1))(input2) #adding 1 channel at the end
平行卷積:
name = Conv1D( feel free to customize )(name)
name = Conv1D( feel free to customize )(name)
mails = Conv1D( feel free to customize )(mails)
mails = Conv1D( feel free to customize )(mails)
連接 - 由於它們具有完全不同的形狀,也許我們應該簡單地將它們平鋪(或者你可以想到花哨的操作來匹配它們)
name = Flatten()(name)
mails = Flatten()(mails)
out = Concatenate()([name,mails])
out = add your extra layers
out = Dense(2,activation='softmax')(out)
最后我們創建了模型:
from keras.models import Model
model = Model([input1,input2], out)
訓練如下:
model.fit([xName,xMails], Y, ....)
您可以使用Keras的功能API構建多輸入網絡。 為每個輸入維度單獨設置一維卷積網絡。 然后連接每個網絡的輸出,並將該連接的矢量傳遞到位於兩個其他網絡之上的一些共享的完全連接的層。
https://keras.io/getting-started/functional-api-guide/#multi-input-and-multi-output-models
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.