簡體   English   中英

如何將層添加到 keras 功能 object(例如 InceptionResNetV2)

[英]How to add layers into keras functional object (e.g InceptionResNetV2)

我正在嘗試將層添加到 InceptionResNetV2(或任何其他可以通過 tf.keras.applications 導入的預訓練網絡)中。 我知道我可以將 object 添加到順序 model 或功能 model 中。 但是,當我這樣做時,我將無法訪問圖層中的單個輸出以在Grad-CAM 或類似應用程序中使用它們。

我現在正在使用以下 model 結構。 它有效,它可以被訓練。 但是,它不允許我訪問關於特定輸入和特定 output 的 InceptionResNetV2 最后一個卷積層的 output。

from tensorflow.keras import layers, models
InceptionResNetV2 = tf.keras.applications.inception_resnet_v2.InceptionResNetV2

def get_base():
    conv_base = InceptionResNetV2(weights=None, include_top=False, input_shape=(224, 224, 3))
    conv_base.trainable = False
    return(conv_base)


def get_model():
    base = get_base()

    inputs = tf.keras.Input(shape=(224, 224, 3))
    x = base(inputs, training=False)
    x = layers.Flatten()(x)
    x = layers.Dense(512, "relu")(x)
    x = layers.Dropout(0.25)(x)
    x = layers.Dense(256, "relu")(x)
    x = layers.Dropout(0.25)(x)
    dims = layers.Dense(2, name="Valence_Arousal")(x)
    expression = layers.Dense(2, name="Emotion_Category")(x)


    model = models.Model(inputs=[inputs], outputs=[expression, dims])
    return(model)

print(get_model().summary())

創建嵌套模型后很難擴展它們。 input_tensor參數傳遞給預訓練的 model 會得到預期的結果。

def get_model():

    inputs = tf.keras.Input(shape=(224, 224, 3))
    
    conv_base = InceptionResNetV2(weights=None, include_top=False, input_tensor = inputs)
    conv_base.trainable = False
    
    x = layers.Flatten()(conv_base.output)
    x = layers.Dense(512, "relu")(x)
    x = layers.Dropout(0.25)(x)
    x = layers.Dense(256, "relu")(x)
    x = layers.Dropout(0.25)(x)
    
    dims = layers.Dense(2, name="Valence_Arousal")(x)
    expression = layers.Dense(2, name="Emotion_Category")(x)


    model = models.Model(inputs=[inputs], outputs=[expression, dims])
    return(model)

Model總結:

input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
conv2d (Conv2D)                (None, 111, 111, 32  864         ['input_1[0][0]']                
                                )  
...

                                                           

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM