
Transfer learning by loading weights partially

Hi, I have a pretrained network, model1, shown below:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 3000, 200, 1) 0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 3000, 200, 1) 26          input_1[0][0]                    
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 1500, 200, 1) 0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 1500, 200, 1) 26          max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 750, 200, 1)  0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 750, 200, 1)  26          max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 375, 200, 1)  0           conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 375, 200, 1)  26          max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 187, 199, 1)  0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 128, 128, 1)  4321        max_pooling2d_4[0][0]            
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 128, 128, 16) 160         conv2d_5[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 128, 128, 16) 2320        conv2d_6[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_5 (MaxPooling2D)  (None, 64, 64, 16)   0           conv2d_7[0][0]                   
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 64, 64, 32)   4640        max_pooling2d_5[0][0]            
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 64, 64, 32)   9248        conv2d_8[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_6 (MaxPooling2D)  (None, 32, 32, 32)   0           conv2d_9[0][0]                   
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 32, 32, 64)   18496       max_pooling2d_6[0][0]            
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 32, 32, 64)   36928       conv2d_10[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_7 (MaxPooling2D)  (None, 16, 16, 64)   0           conv2d_11[0][0]                  
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 16, 16, 128)  73856       max_pooling2d_7[0][0]            
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 16, 16, 128)  147584      conv2d_12[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_8 (MaxPooling2D)  (None, 8, 8, 128)    0           conv2d_13[0][0]                  
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 8, 8, 256)    295168      max_pooling2d_8[0][0]            
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 8, 8, 256)    590080      conv2d_14[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 16, 16, 256)  0           conv2d_15[0][0]                  
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 16, 16, 384)  0           up_sampling2d_1[0][0]            
                                                                 conv2d_13[0][0]                  
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 16, 16, 128)  442496      concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 16, 16, 128)  147584      conv2d_16[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 32, 32, 128)  0           conv2d_17[0][0]                  
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 32, 32, 192)  0           up_sampling2d_2[0][0]            
                                                                 conv2d_11[0][0]                  
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 32, 32, 64)   110656      concatenate_2[0][0]              
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 32, 32, 64)   36928       conv2d_18[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D)  (None, 64, 64, 64)   0           conv2d_19[0][0]                  
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 64, 64, 96)   0           up_sampling2d_3[0][0]            
                                                                 conv2d_9[0][0]                   
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 64, 64, 32)   27680       concatenate_3[0][0]              
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 64, 64, 32)   9248        conv2d_20[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D)  (None, 128, 128, 32) 0           conv2d_21[0][0]                  
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 128, 128, 48) 0           up_sampling2d_4[0][0]            
                                                                 conv2d_7[0][0]                   
__________________________________________________________________________________________________
conv2d_22 (Conv2D)              (None, 128, 128, 16) 6928        concatenate_4[0][0]              
__________________________________________________________________________________________________
conv2d_23 (Conv2D)              (None, 128, 128, 16) 2320        conv2d_22[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 261, 261, 8)  6280        conv2d_23[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 267, 267, 4)  1572        conv2d_transpose_1[0][0]         
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 300, 300, 2)  9250        conv2d_transpose_2[0][0]         
__________________________________________________________________________________________________
conv2d_24 (Conv2D)              (None, 300, 300, 1)  3           conv2d_transpose_3[0][0]         
==================================================================================================
Total params: 1,983,850
Trainable params: 1,983,850
Non-trainable params: 0

I want to take the weights of model1 from conv2d_6 to the end and load them into model2 (shown below) from conv2d_3 to the end.

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 502, 200, 3)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 502, 200, 3)  30          input_1[0][0]                    
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 251, 200, 3)  0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 251, 200, 3)  48          max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 125, 100, 3)  0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 126, 109, 3)  183         max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 127, 118, 3)  183         conv2d_transpose_1[0][0]         
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 128, 128, 3)  201         conv2d_transpose_2[0][0]         
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 128, 128, 16) 448         conv2d_transpose_3[0][0]         
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 128, 128, 16) 2320        conv2d_3[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 64, 64, 16)   0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 64, 64, 32)   4640        max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 64, 64, 32)   9248        conv2d_5[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 32, 32, 32)   0           conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 32, 32, 64)   18496       max_pooling2d_4[0][0]            
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 32, 32, 64)   36928       conv2d_7[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_5 (MaxPooling2D)  (None, 16, 16, 64)   0           conv2d_8[0][0]                   
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 16, 16, 128)  73856       max_pooling2d_5[0][0]            
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 16, 16, 128)  147584      conv2d_9[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_6 (MaxPooling2D)  (None, 8, 8, 128)    0           conv2d_10[0][0]                  
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 8, 8, 256)    295168      max_pooling2d_6[0][0]            
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 8, 8, 256)    590080      conv2d_11[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 16, 16, 256)  0           conv2d_12[0][0]                  
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 16, 16, 384)  0           up_sampling2d_1[0][0]            
                                                                 conv2d_10[0][0]                  
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 16, 16, 128)  442496      concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 16, 16, 128)  147584      conv2d_13[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 32, 32, 128)  0           conv2d_14[0][0]                  
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 32, 32, 192)  0           up_sampling2d_2[0][0]            
                                                                 conv2d_8[0][0]                   
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 32, 32, 64)   110656      concatenate_2[0][0]              
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 32, 32, 64)   36928       conv2d_15[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D)  (None, 64, 64, 64)   0           conv2d_16[0][0]                  
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 64, 64, 96)   0           up_sampling2d_3[0][0]            
                                                                 conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 64, 64, 32)   27680       concatenate_3[0][0]              
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 64, 64, 32)   9248        conv2d_17[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D)  (None, 128, 128, 32) 0           conv2d_18[0][0]                  
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 128, 128, 48) 0           up_sampling2d_4[0][0]            
                                                                 conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 128, 128, 16) 6928        concatenate_4[0][0]              
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 128, 128, 16) 2320        conv2d_19[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_4 (Conv2DTrans (None, 261, 261, 8)  6280        conv2d_20[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_5 (Conv2DTrans (None, 267, 267, 4)  1572        conv2d_transpose_4[0][0]         
__________________________________________________________________________________________________
conv2d_transpose_6 (Conv2DTrans (None, 300, 300, 2)  9250        conv2d_transpose_5[0][0]         
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 300, 300, 1)  3           conv2d_transpose_6[0][0]         
==================================================================================================
Total params: 1,980,358
Trainable params: 1,980,358
Non-trainable params: 0

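From the pretrained model I got weights_list = model1.get_weights(), and the length of this list is 54. I just can't work out how the weights are indexed by layer. Model1 has 43 layers, and the indexing confuses me. Is there a general way to understand the indexing for other models? In future models I will be picking specific layer weights to transfer between models.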

You need to extract the weights from specific layers. On its own, model.get_weights() returns the weights of the whole model as one flat list.
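To see how that flat list lines up with layers, the sketch below uses a hypothetical helper, `weight_index_map`, that walks `model.layers` and records, for each position in `get_weights()`, which layer and which slot it came from. It assumes the flat list is each layer's `get_weights()` concatenated in `model.layers` order (the order `model.summary()` prints); comparing `weights_list[i].shape` against the map is a quick way to confirm that for a given model. It also accounts for the 54 in the question: in model1 only the 24 Conv2D and 3 Conv2DTranspose layers own weights, and each contributes a kernel and a bias, so 27 × 2 = 54.

```python
def weight_index_map(model):
    """Hypothetical helper: describe where each entry of model.get_weights() comes from.

    Returns a list whose i-th item is (layer_name, slot) for get_weights()[i].
    Assumes get_weights() concatenates each layer's get_weights() in
    model.layers order.
    """
    index_map = []
    for layer in model.layers:
        # Layers without weights (InputLayer, pooling, upsampling, concatenate)
        # add nothing, which is why weight indices do not match layer indices.
        for slot in range(len(layer.get_weights())):
            index_map.append((layer.name, slot))
    return index_map


# Usage sketch with the question's model (for Conv2D/Conv2DTranspose layers
# with a bias, slot 0 is the kernel and slot 1 is the bias):
# weights_list = model1.get_weights()
# for i, (name, slot) in enumerate(weight_index_map(model1)):
#     print(i, name, slot, weights_list[i].shape)
```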

The following should work for a pair of layers whose weight shapes match:

model1_weights = model1.get_layer('conv2d_6').get_weights()
model2.get_layer('conv2d_3').set_weights(model1_weights)
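One caveat for this particular pair: according to the two summaries, model1's conv2d_6 has 160 parameters while model2's conv2d_3 has 448, consistent with 3×3 kernels over 1 and 3 input channels respectively. The kernel shapes therefore differ and `set_weights` should reject that copy with a shape error. The later layers in both tails do match (e.g. model1's conv2d_7 and model2's conv2d_4 both have 2320 parameters). A minimal sketch, using a hypothetical helper `copy_layer_weights`, that compares shapes before copying and reports pairs it skips:

```python
def copy_layer_weights(src_model, src_name, dst_model, dst_name):
    """Hypothetical helper: copy weights between two named layers if all shapes match.

    Returns True when the weights were copied, False when any array shape
    differs (in that case set_weights() is never called).
    """
    src_weights = src_model.get_layer(src_name).get_weights()
    dst_layer = dst_model.get_layer(dst_name)
    src_shapes = [w.shape for w in src_weights]
    dst_shapes = [w.shape for w in dst_layer.get_weights()]
    if src_shapes != dst_shapes:
        print(f"skip {src_name} -> {dst_name}: {src_shapes} vs {dst_shapes}")
        return False
    dst_layer.set_weights(src_weights)
    return True


# Usage sketch with the question's models:
# copy_layer_weights(model1, "conv2d_6", model2, "conv2d_3")  # expected to be skipped
# copy_layer_weights(model1, "conv2d_7", model2, "conv2d_4")  # shapes match
```

A skipped layer simply keeps model2's own initial weights.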

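Edit: To extend this to all the layers you need, the simplest approach is to create a list of the names of the model1 layers you want to transfer from (model1_layers) and a list of the names of the model2 layers you want to transfer to (model2_layers). Then do the transfer in a for loop, as follows: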

# model1_layers[i] is copied into model2_layers[i]; both are lists of layer names.
for src_name, dst_name in zip(model1_layers, model2_layers):
    layer_weights = model1.get_layer(src_name).get_weights()
    model2.get_layer(dst_name).set_weights(layer_weights)
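For the two models in the question, the name lists can be read off the summaries: each tail contains 22 weight-bearing layers (19 Conv2D and 3 Conv2DTranspose), listed below in summary order, and the names are offset by a fixed amount. For future models, a small hypothetical helper, `weighted_layer_names`, can derive such a list from the model itself. Both parts are sketches: the hard-coded names assume the auto-generated names shown in the summaries (Keras numbers layers per session, so rebuilding models can change them), and the helper assumes `model.layers` is in the order `model.summary()` prints.

```python
def weighted_layer_names(model, start_name):
    """Hypothetical helper: names of weight-bearing layers from start_name to the end.

    Walks model.layers and keeps only layers that own weights, so InputLayer,
    pooling, upsampling and concatenate layers drop out.
    """
    names = [layer.name for layer in model.layers if layer.get_weights()]
    # list.index raises ValueError if start_name is missing or owns no weights.
    return names[names.index(start_name):]


# Hard-coded lists for the question's models, in the order the summaries print them.
model1_layers = (
    [f"conv2d_{k}" for k in range(6, 24)]               # conv2d_6 .. conv2d_23
    + [f"conv2d_transpose_{j}" for j in range(1, 4)]    # conv2d_transpose_1 .. 3
    + ["conv2d_24"]
)
model2_layers = (
    [f"conv2d_{k}" for k in range(3, 21)]               # conv2d_3 .. conv2d_20
    + [f"conv2d_transpose_{j}" for j in range(4, 7)]    # conv2d_transpose_4 .. 6
    + ["conv2d_21"]
)

# Pairing by position only makes sense when both tails have the same structure.
assert len(model1_layers) == len(model2_layers) == 22
```

With the real models, `weighted_layer_names(model1, "conv2d_6")` should reproduce `model1_layers` (and likewise for model2 from "conv2d_3"), which is a cheap check before running the copy loop.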


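Disclaimer: The technical posts on this site are licensed under CC BY-SA 4.0. If you repost them, please credit this site's URL or the original source. For any questions, contact: yoyou2525@163.com.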

 
Guangdong ICP Filing No. 18138465  © 2020-2024 STACKOOM.COM