簡體   English   中英

ResNet(2D 圖像)與全連接網絡(1D 輸入)的串聯

[英]Concatenation of ResNet (2D images) with fully-connected network (1D input)

我通過以下方式在 Keras (TensorFlow 2) 中使用預構建的 ResNet:

from tensorflow.keras.applications.resnet50 import ResNet50
base_model = ResNet50(weights=None, include_top=False, input_shape=(39,39,3))
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.5)(x)
output_tensor = Dense(self.num_classes, activation='softmax')(x)
cnn_model = Model(inputs=base_model.input, outputs=output_tensor)
opt = Adam(lr=0.001)
cnn_model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy', tf.keras.metrics.AUC()])

model ( base_model.input ) 的輸入是39 x 39 x 3圖像。 此外,我現在還想向 model (即20 x 1 )提供一個帶有附加信息的 20 維向量。 我可以通過兩種不同的方式做到這一點:

  1. GlobalAveragePooling2D步驟之后附加 20 維向量。
  2. 為 20 維向量創建一個額外的全連接網絡,並在GlobalAveragePooling2D步驟之后將這個全連接網絡的 output 連接到上述 ResNet。 理想情況下,兩個網絡都是同時訓練的,但我不知道這是否可能。

我可以為這兩個選項調整我的模型還是不起作用?

是的,這兩個選項都有意義,並且可以使用 Keras。 對於#2,您可以定義另一個 model ,它將 20D 向量作為輸入並將其傳遞給全連接層,然后將 output 與池化層的 output 連接起來。 對於這兩個選項,您必須調整最終的 model 輸入以包括 base_model 輸入和 20D 向量。

應該這樣做,注釋掉密集層以在全局平均池化之后將它們連接起來。

from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.layers import GlobalAveragePooling2D, Dropout, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
import tensorflow as tf


base_model = ResNet50(weights=None, include_top=False, input_shape=(39, 39, 3))
x1 = base_model.output
x1 = GlobalAveragePooling2D()(x1)
x1 = Dropout(0.5)(x1)

input_2 = tf.keras.layers.Input(shape=(20, 1))
x2 = tf.keras.layers.Flatten()(input_2)
# comment this if needed.
x2 = tf.keras.layers.Dense(16, activation='relu')(x2)

x = tf.keras.layers.Concatenate()([x1, x2])

output_tensor = Dense(self.num_classes, activation='softmax')(x)
cnn_model = Model(inputs=[base_model.input, input_2], outputs=output_tensor)
opt = Adam(lr=0.001)
cnn_model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy', tf.keras.metrics.AUC()])
print(cnn_model.summary())

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM