簡體   English   中英

通過for循環使用功能api創建keras輸入層?

[英]create keras input layers using the functional api through a for loop?

假設我需要動態生成特定於用戶的 keras 模型。 每個用戶都可以擁有可變數量的分類輸入,但是一旦您知道分類輸入的數量,手動構建模型就變得很簡單了。

我想編寫一個函數,給出每個分類變量的基數列表將返回一個適當的模型。 我對這個問題的第一次嘗試產生了以下解決方案,但執行這樣的字符串似乎不正確。

from keras.layers import Dense,Embedding,Input,Flatten,Add
from keras.models import Model

def build_model(input_cardinalities,num_outputs):
    layers = []
    inputs = []
    for i,cardinality in enumerate(input_cardinalities):
        exec("input{0} = Input(shape=[1], name='input{0}')".format(i))
        exec("embedding{0} =  Embedding({1}, 20, name='embedding{0}')(input{0})".format(i,cardinality))
        exec("vec{0} = Flatten(name='flatten{0}')(embedding{0})".format(i))
        exec("layers.append(vec{0})".format(i))
        exec("inputs.append(input{0})".format(i))
    context_layer = Add(layers)
    dense1 = Dense(50, name='Dense1',activation='relu')(context_layer)
    dense2 = Dense(num_outputs, name='Output', activation='softmax')(dense1)
    model = Model(inputs,dense2)
    model.compile('sgd','categorical_crossentropy')
    return model

我只是覺得像這樣執行字符串不太舒服,但這是我能想到的做我想做的事情的唯一方法。 我只是覺得應該有更好的方法來做到這一點。

實際上根本不需要使用exec ,您一次構建一個輸入/嵌入,然后將它們存儲到列表中。 這是正確的方法,它不需要exec

def build_model(input_cardinalities,num_outputs):
    layers = []
    inputs = []
    for i,cardinality in enumerate(input_cardinalities):
        input = Input(shape=[1], name='input{0}'.format(i))
        embedding =  Embedding(cardinality, 20, name='embedding{0}'.format(i))
        vec = Flatten(name='flatten{0}'.format(i))(embedding)
        layers.append(vec)
        inputs.append(input)
    context_layer = Add()(layers)
    dense1 = Dense(50, name='Dense1',activation='relu')(context_layer)
    dense2 = Dense(num_outputs, name='Output', activation='softmax')(dense1)
    model = Model(inputs,dense2)
    model.compile('sgd','categorical_crossentropy')
    return model

另請注意,我更正了Add()(layers)調用。

我遇到了類似的問題,我想使用 for 循環實例化多個輸入層。 首先,我嘗試重現上述內容,並僅完成上述答案的全面性(和打印模型摘要)。

from keras.layers import Dense,Embedding,Input,Flatten,Add
from keras.models import Model
from tensorflow.keras.utils import plot_model

def build_model(input_cardinalities,num_outputs):
    layers = []
    inputs = []
    for i,cardinality in enumerate(input_cardinalities):
        input = Input(shape=[1], name='input{0}'.format(i))
        embedding =  Embedding(cardinality, 20, name='embedding{0}'.format(i))(input)
        vec = Flatten(name='flatten{0}'.format(i))(embedding)
        layers.append(vec)
        inputs.append(input)
    context_layer = Add(name='context')(layers)
    dense1 = Dense(50, name='Dense1',activation='relu')(context_layer)
    dense2 = Dense(num_outputs, name='Output', activation='softmax')(dense1)
    model = Model(inputs,dense2)
    model.compile('sgd','categorical_crossentropy')
    return model

然后構建模型

input_cardinalities = [1,2,3]
num_outputs = 6
model = build_model(input_cardinalities, num_outputs)
model.summary()
plot_model(model, 'model.png', show_shapes=True)

輸出看起來像這樣

Model: "model_14"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input0 (InputLayer)             [(None, 1)]          0                                            
__________________________________________________________________________________________________
input1 (InputLayer)             [(None, 1)]          0                                            
__________________________________________________________________________________________________
input2 (InputLayer)             [(None, 1)]          0                                            
__________________________________________________________________________________________________
embedding0 (Embedding)          (None, 1, 20)        20          input0[0][0]                     
__________________________________________________________________________________________________
embedding1 (Embedding)          (None, 1, 20)        40          input1[0][0]                     
__________________________________________________________________________________________________
embedding2 (Embedding)          (None, 1, 20)        60          input2[0][0]                     
__________________________________________________________________________________________________
flatten0 (Flatten)              (None, 20)           0           embedding0[0][0]                 
__________________________________________________________________________________________________
flatten1 (Flatten)              (None, 20)           0           embedding1[0][0]                 
__________________________________________________________________________________________________
flatten2 (Flatten)              (None, 20)           0           embedding2[0][0]                 
__________________________________________________________________________________________________
context (Add)                   (None, 20)           0           flatten0[0][0]                   
                                                                 flatten1[0][0]                   
                                                                 flatten2[0][0]                   
__________________________________________________________________________________________________
Dense1 (Dense)                  (None, 50)           1050        context[0][0]                    
__________________________________________________________________________________________________
Output (Dense)                  (None, 6)            306         Dense1[0][0]                     
==================================================================================================
Total params: 1,476
Trainable params: 1,476
Non-trainable params: 0
__________________________________________________________________________________________________

為了更好的可視化,請在此處查看此模型的 plot_model() 返回

末尾缺少一個(input)

embedding =  Embedding(cardinality, 20, name='embedding{0}'.format(i))(input)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM