Transfer learning by loading weights partially
Hi, I have a pretrained network, model1, as shown below:
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 3000, 200, 1) 0
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 3000, 200, 1) 26 input_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 1500, 200, 1) 0 conv2d_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 1500, 200, 1) 26 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D) (None, 750, 200, 1) 0 conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 750, 200, 1) 26 max_pooling2d_2[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D) (None, 375, 200, 1) 0 conv2d_3[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 375, 200, 1) 26 max_pooling2d_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D) (None, 187, 199, 1) 0 conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 128, 128, 1) 4321 max_pooling2d_4[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, 128, 128, 16) 160 conv2d_5[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, 128, 128, 16) 2320 conv2d_6[0][0]
__________________________________________________________________________________________________
max_pooling2d_5 (MaxPooling2D) (None, 64, 64, 16) 0 conv2d_7[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, 64, 64, 32) 4640 max_pooling2d_5[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, 64, 64, 32) 9248 conv2d_8[0][0]
__________________________________________________________________________________________________
max_pooling2d_6 (MaxPooling2D) (None, 32, 32, 32) 0 conv2d_9[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, 32, 32, 64) 18496 max_pooling2d_6[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, 32, 32, 64) 36928 conv2d_10[0][0]
__________________________________________________________________________________________________
max_pooling2d_7 (MaxPooling2D) (None, 16, 16, 64) 0 conv2d_11[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, 16, 16, 128) 73856 max_pooling2d_7[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, 16, 16, 128) 147584 conv2d_12[0][0]
__________________________________________________________________________________________________
max_pooling2d_8 (MaxPooling2D) (None, 8, 8, 128) 0 conv2d_13[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D) (None, 8, 8, 256) 295168 max_pooling2d_8[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D) (None, 8, 8, 256) 590080 conv2d_14[0][0]
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, 16, 16, 256) 0 conv2d_15[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 16, 16, 384) 0 up_sampling2d_1[0][0]
conv2d_13[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D) (None, 16, 16, 128) 442496 concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D) (None, 16, 16, 128) 147584 conv2d_16[0][0]
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D) (None, 32, 32, 128) 0 conv2d_17[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 32, 32, 192) 0 up_sampling2d_2[0][0]
conv2d_11[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D) (None, 32, 32, 64) 110656 concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_19 (Conv2D) (None, 32, 32, 64) 36928 conv2d_18[0][0]
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D) (None, 64, 64, 64) 0 conv2d_19[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate) (None, 64, 64, 96) 0 up_sampling2d_3[0][0]
conv2d_9[0][0]
__________________________________________________________________________________________________
conv2d_20 (Conv2D) (None, 64, 64, 32) 27680 concatenate_3[0][0]
__________________________________________________________________________________________________
conv2d_21 (Conv2D) (None, 64, 64, 32) 9248 conv2d_20[0][0]
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D) (None, 128, 128, 32) 0 conv2d_21[0][0]
__________________________________________________________________________________________________
concatenate_4 (Concatenate) (None, 128, 128, 48) 0 up_sampling2d_4[0][0]
conv2d_7[0][0]
__________________________________________________________________________________________________
conv2d_22 (Conv2D) (None, 128, 128, 16) 6928 concatenate_4[0][0]
__________________________________________________________________________________________________
conv2d_23 (Conv2D) (None, 128, 128, 16) 2320 conv2d_22[0][0]
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 261, 261, 8) 6280 conv2d_23[0][0]
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 267, 267, 4) 1572 conv2d_transpose_1[0][0]
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 300, 300, 2) 9250 conv2d_transpose_2[0][0]
__________________________________________________________________________________________________
conv2d_24 (Conv2D) (None, 300, 300, 1) 3 conv2d_transpose_3[0][0]
==================================================================================================
Total params: 1,983,850
Trainable params: 1,983,850
Non-trainable params: 0
I want to take the weights of model1 from conv2d_6 to the end and load them into model2 (shown below), from conv2d_3 to the end.
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 502, 200, 3) 0
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 502, 200, 3) 30 input_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 251, 200, 3) 0 conv2d_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 251, 200, 3) 48 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D) (None, 125, 100, 3) 0 conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 126, 109, 3) 183 max_pooling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 127, 118, 3) 183 conv2d_transpose_1[0][0]
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 128, 128, 3) 201 conv2d_transpose_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 128, 128, 16) 448 conv2d_transpose_3[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 128, 128, 16) 2320 conv2d_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D) (None, 64, 64, 16) 0 conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 64, 64, 32) 4640 max_pooling2d_3[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, 64, 64, 32) 9248 conv2d_5[0][0]
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D) (None, 32, 32, 32) 0 conv2d_6[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, 32, 32, 64) 18496 max_pooling2d_4[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, 32, 32, 64) 36928 conv2d_7[0][0]
__________________________________________________________________________________________________
max_pooling2d_5 (MaxPooling2D) (None, 16, 16, 64) 0 conv2d_8[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, 16, 16, 128) 73856 max_pooling2d_5[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, 16, 16, 128) 147584 conv2d_9[0][0]
__________________________________________________________________________________________________
max_pooling2d_6 (MaxPooling2D) (None, 8, 8, 128) 0 conv2d_10[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, 8, 8, 256) 295168 max_pooling2d_6[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, 8, 8, 256) 590080 conv2d_11[0][0]
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, 16, 16, 256) 0 conv2d_12[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 16, 16, 384) 0 up_sampling2d_1[0][0]
conv2d_10[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, 16, 16, 128) 442496 concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D) (None, 16, 16, 128) 147584 conv2d_13[0][0]
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D) (None, 32, 32, 128) 0 conv2d_14[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 32, 32, 192) 0 up_sampling2d_2[0][0]
conv2d_8[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D) (None, 32, 32, 64) 110656 concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D) (None, 32, 32, 64) 36928 conv2d_15[0][0]
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D) (None, 64, 64, 64) 0 conv2d_16[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate) (None, 64, 64, 96) 0 up_sampling2d_3[0][0]
conv2d_6[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D) (None, 64, 64, 32) 27680 concatenate_3[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D) (None, 64, 64, 32) 9248 conv2d_17[0][0]
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D) (None, 128, 128, 32) 0 conv2d_18[0][0]
__________________________________________________________________________________________________
concatenate_4 (Concatenate) (None, 128, 128, 48) 0 up_sampling2d_4[0][0]
conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_19 (Conv2D) (None, 128, 128, 16) 6928 concatenate_4[0][0]
__________________________________________________________________________________________________
conv2d_20 (Conv2D) (None, 128, 128, 16) 2320 conv2d_19[0][0]
__________________________________________________________________________________________________
conv2d_transpose_4 (Conv2DTrans (None, 261, 261, 8) 6280 conv2d_20[0][0]
__________________________________________________________________________________________________
conv2d_transpose_5 (Conv2DTrans (None, 267, 267, 4) 1572 conv2d_transpose_4[0][0]
__________________________________________________________________________________________________
conv2d_transpose_6 (Conv2DTrans (None, 300, 300, 2) 9250 conv2d_transpose_5[0][0]
__________________________________________________________________________________________________
conv2d_21 (Conv2D) (None, 300, 300, 1) 3 conv2d_transpose_6[0][0]
==================================================================================================
Total params: 1,980,358
Trainable params: 1,980,358
Non-trainable params: 0
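I got weights_list = model1.get_weights() from the pretrained model, and the length of this list is 54. I just can't work out how the indices in it map to the layers. model1 has 43 layers, so the indexing confuses me. Is there a general way to understand the indexing for other models? In future models I will be picking specific layers' weights to move between models.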
You need to extract the weights from specific layers. Calling model.get_weights() on its own returns the weights of the whole model.
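As a rough guide to how that flat list is indexed: the summary of model1 shows 27 layers with parameters (24 Conv2D and 3 Conv2DTranspose), and each one contributes a kernel and a bias, which accounts for the 54 entries. The sketch below walks model.layers and records where each layer's arrays would start in the list. weight_index_map is a hypothetical helper name, and it assumes get_weights() lists weights layer by layer in model.layers order, which should hold here since every parameter is trainable.

```python
def weight_index_map(model):
    """Return (start_index, layer_name, shapes) for each layer that has weights.

    Assumption: model.get_weights() concatenates each layer's arrays in
    model.layers order, so start_index is the position of the layer's first
    array (usually the kernel) in that flat list.
    """
    index_map = []
    start = 0
    for layer in model.layers:
        shapes = [w.shape for w in layer.get_weights()]
        if shapes:  # pooling, upsampling, concatenate and input layers have none
            index_map.append((start, layer.name, shapes))
            start += len(shapes)
    return index_map


# Usage (model1 is the pretrained model from the question):
# for start, name, shapes in weight_index_map(model1):
#     print(start, name, shapes)
```

Under that assumption, conv2d_6 of model1 would sit at indices 10 and 11, after the kernels and biases of conv2d_1 to conv2d_5.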
For a single pair of layers, the following should work:
model1_weights = model1.get_layer('conv2d_6').get_weights()
model2.get_layer('conv2d_3').set_weights(model1_weights)
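One caveat from the two summaries: conv2d_6 in model1 has 160 parameters while conv2d_3 in model2 has 448, because their inputs have 1 and 3 channels respectively (consistent with 3x3 kernels and 16 filters). set_weights expects the incoming arrays to have the same shapes as the layer's own, so that particular pair is likely to be rejected, whereas the next pair (conv2d_7 to conv2d_4, 2320 parameters each) should go through. A minimal sketch of a guarded copy follows; copy_layer_weights is a hypothetical helper, not a Keras function.

```python
def copy_layer_weights(src_model, src_name, dst_model, dst_name):
    """Copy weights between two named layers only if all array shapes agree.

    Returns True when the weights were copied, False when the shapes differ
    and the destination layer was left untouched.
    """
    src_weights = src_model.get_layer(src_name).get_weights()
    dst_layer = dst_model.get_layer(dst_name)

    src_shapes = [w.shape for w in src_weights]
    dst_shapes = [w.shape for w in dst_layer.get_weights()]
    if src_shapes != dst_shapes:
        return False

    dst_layer.set_weights(src_weights)
    return True


# Usage (model1 and model2 are the models from the question):
# copied = copy_layer_weights(model1, 'conv2d_7', model2, 'conv2d_4')
```

Returning a flag instead of raising makes it easy to collect the layers that were skipped and decide separately how to initialise them.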
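Edit: To extend this to all the required layers, the simplest approach is to create a list of the layer names to transfer from in model1 (model1_layers) and a list of the layer names to transfer to in model2 (model2_layers). Then do the transfer in a for loop, as shown below.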
for i in range(len(model1_layers)):
    layer_weights = model1.get_layer(model1_layers[i]).get_weights()
    model2.get_layer(model2_layers[i]).set_weights(layer_weights)
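The two name lists do not have to be typed by hand. A possible sketch, assuming model.layers follows the same order as the printed summary: collect the names of every layer that holds weights from a given starting layer to the end of each model. weighted_layer_names and paired_layer_names are hypothetical helper names. For the two models above, both lists should come out with 22 entries, but each pair still needs matching array shapes for set_weights to accept it.

```python
def weighted_layer_names(model, start_name):
    """Names of the layers that hold weights, from start_name to the model's end."""
    names = [layer.name for layer in model.layers]
    start = names.index(start_name)  # raises ValueError if the name is absent
    return [layer.name for layer in model.layers[start:] if layer.get_weights()]


def paired_layer_names(src_model, src_start, dst_model, dst_start):
    """Build the source and destination name lists and check they line up."""
    src_names = weighted_layer_names(src_model, src_start)
    dst_names = weighted_layer_names(dst_model, dst_start)
    if len(src_names) != len(dst_names):
        raise ValueError(
            "layer counts differ: %d vs %d" % (len(src_names), len(dst_names))
        )
    return src_names, dst_names


# Usage (model1 and model2 are the models from the question):
# model1_layers, model2_layers = paired_layer_names(
#     model1, 'conv2d_6', model2, 'conv2d_3')
```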