Transfer learning by loading weights partially

Hi, I have a pretrained network, model1, as follows:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 3000, 200, 1) 0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 3000, 200, 1) 26          input_1[0][0]                    
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 1500, 200, 1) 0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 1500, 200, 1) 26          max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 750, 200, 1)  0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 750, 200, 1)  26          max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 375, 200, 1)  0           conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 375, 200, 1)  26          max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 187, 199, 1)  0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 128, 128, 1)  4321        max_pooling2d_4[0][0]            
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 128, 128, 16) 160         conv2d_5[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 128, 128, 16) 2320        conv2d_6[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_5 (MaxPooling2D)  (None, 64, 64, 16)   0           conv2d_7[0][0]                   
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 64, 64, 32)   4640        max_pooling2d_5[0][0]            
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 64, 64, 32)   9248        conv2d_8[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_6 (MaxPooling2D)  (None, 32, 32, 32)   0           conv2d_9[0][0]                   
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 32, 32, 64)   18496       max_pooling2d_6[0][0]            
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 32, 32, 64)   36928       conv2d_10[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_7 (MaxPooling2D)  (None, 16, 16, 64)   0           conv2d_11[0][0]                  
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 16, 16, 128)  73856       max_pooling2d_7[0][0]            
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 16, 16, 128)  147584      conv2d_12[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_8 (MaxPooling2D)  (None, 8, 8, 128)    0           conv2d_13[0][0]                  
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 8, 8, 256)    295168      max_pooling2d_8[0][0]            
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 8, 8, 256)    590080      conv2d_14[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 16, 16, 256)  0           conv2d_15[0][0]                  
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 16, 16, 384)  0           up_sampling2d_1[0][0]            
                                                                 conv2d_13[0][0]                  
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 16, 16, 128)  442496      concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 16, 16, 128)  147584      conv2d_16[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 32, 32, 128)  0           conv2d_17[0][0]                  
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 32, 32, 192)  0           up_sampling2d_2[0][0]            
                                                                 conv2d_11[0][0]                  
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 32, 32, 64)   110656      concatenate_2[0][0]              
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 32, 32, 64)   36928       conv2d_18[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D)  (None, 64, 64, 64)   0           conv2d_19[0][0]                  
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 64, 64, 96)   0           up_sampling2d_3[0][0]            
                                                                 conv2d_9[0][0]                   
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 64, 64, 32)   27680       concatenate_3[0][0]              
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 64, 64, 32)   9248        conv2d_20[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D)  (None, 128, 128, 32) 0           conv2d_21[0][0]                  
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 128, 128, 48) 0           up_sampling2d_4[0][0]            
                                                                 conv2d_7[0][0]                   
__________________________________________________________________________________________________
conv2d_22 (Conv2D)              (None, 128, 128, 16) 6928        concatenate_4[0][0]              
__________________________________________________________________________________________________
conv2d_23 (Conv2D)              (None, 128, 128, 16) 2320        conv2d_22[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 261, 261, 8)  6280        conv2d_23[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 267, 267, 4)  1572        conv2d_transpose_1[0][0]         
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 300, 300, 2)  9250        conv2d_transpose_2[0][0]         
__________________________________________________________________________________________________
conv2d_24 (Conv2D)              (None, 300, 300, 1)  3           conv2d_transpose_3[0][0]         
==================================================================================================
Total params: 1,983,850
Trainable params: 1,983,850
Non-trainable params: 0

I want to take the weights in model1 from conv2d_6 to the end and load them into model2 (shown below) from conv2d_3 to the end.

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 502, 200, 3)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 502, 200, 3)  30          input_1[0][0]                    
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 251, 200, 3)  0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 251, 200, 3)  48          max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 125, 100, 3)  0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 126, 109, 3)  183         max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 127, 118, 3)  183         conv2d_transpose_1[0][0]         
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 128, 128, 3)  201         conv2d_transpose_2[0][0]         
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 128, 128, 16) 448         conv2d_transpose_3[0][0]         
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 128, 128, 16) 2320        conv2d_3[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 64, 64, 16)   0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 64, 64, 32)   4640        max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 64, 64, 32)   9248        conv2d_5[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 32, 32, 32)   0           conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 32, 32, 64)   18496       max_pooling2d_4[0][0]            
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 32, 32, 64)   36928       conv2d_7[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_5 (MaxPooling2D)  (None, 16, 16, 64)   0           conv2d_8[0][0]                   
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 16, 16, 128)  73856       max_pooling2d_5[0][0]            
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 16, 16, 128)  147584      conv2d_9[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_6 (MaxPooling2D)  (None, 8, 8, 128)    0           conv2d_10[0][0]                  
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 8, 8, 256)    295168      max_pooling2d_6[0][0]            
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 8, 8, 256)    590080      conv2d_11[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 16, 16, 256)  0           conv2d_12[0][0]                  
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 16, 16, 384)  0           up_sampling2d_1[0][0]            
                                                                 conv2d_10[0][0]                  
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 16, 16, 128)  442496      concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 16, 16, 128)  147584      conv2d_13[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 32, 32, 128)  0           conv2d_14[0][0]                  
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 32, 32, 192)  0           up_sampling2d_2[0][0]            
                                                                 conv2d_8[0][0]                   
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 32, 32, 64)   110656      concatenate_2[0][0]              
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 32, 32, 64)   36928       conv2d_15[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D)  (None, 64, 64, 64)   0           conv2d_16[0][0]                  
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 64, 64, 96)   0           up_sampling2d_3[0][0]            
                                                                 conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 64, 64, 32)   27680       concatenate_3[0][0]              
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 64, 64, 32)   9248        conv2d_17[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D)  (None, 128, 128, 32) 0           conv2d_18[0][0]                  
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 128, 128, 48) 0           up_sampling2d_4[0][0]            
                                                                 conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 128, 128, 16) 6928        concatenate_4[0][0]              
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 128, 128, 16) 2320        conv2d_19[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_4 (Conv2DTrans (None, 261, 261, 8)  6280        conv2d_20[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_5 (Conv2DTrans (None, 267, 267, 4)  1572        conv2d_transpose_4[0][0]         
__________________________________________________________________________________________________
conv2d_transpose_6 (Conv2DTrans (None, 300, 300, 2)  9250        conv2d_transpose_5[0][0]         
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 300, 300, 1)  3           conv2d_transpose_6[0][0]         
==================================================================================================
Total params: 1,980,358
Trainable params: 1,980,358
Non-trainable params: 0

I took weights_list = model1.get_weights() from the pre-trained model, and the length of this list is 54. I can't understand how the weights in this list are indexed by layer: model1 has 43 layers, so the indexing is confusing me. Is there a general way to understand the indexing that works for other models as well? In future models I will be selecting the weights of specific layers in between the models.

You need to extract the weights from the specific layer. Using model.get_weights() by itself returns the weights for the entire model.

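The following should work for any pair of layers whose weight shapes match. Note that conv2d_6 in model1 (160 parameters, 1 input channel) and conv2d_3 in model2 (448 parameters, 3 input channels) do not match, so set_weights would raise an error for that pair. The first compatible pair is conv2d_7 and conv2d_4:

model1_weights = model1.get_layer('conv2d_7').get_weights()
model2.get_layer('conv2d_4').set_weights(model1_weights)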
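
On the indexing question, here is a rough sketch of how the flat list lines up with the layers. It assumes that model.get_weights() is simply each layer's get_weights() concatenated in model.layers order, which I believe holds for a model like this one with no shared layers, but check it on your Keras version. The helper name weight_index_map is made up for illustration. Layers without weights (input, pooling, upsampling, concatenate) contribute nothing to the list, which is why you see only 54 arrays: 27 convolution layers (24 Conv2D plus 3 Conv2DTranspose), each with a kernel and a bias.

```python
def weight_index_map(model):
    """Map each weighted layer's name to its (start, stop) slice of model.get_weights().

    Assumption: the flat list is the per-layer weight lists concatenated in
    model.layers order. Verify by comparing array shapes on your own model.
    """
    index_map = {}
    start = 0
    for layer in model.layers:
        # 0 arrays for pooling/concatenate layers, 2 (kernel, bias) for a conv with bias
        count = len(layer.get_weights())
        if count:
            index_map[layer.name] = (start, start + count)
        start += count
    return index_map
```

If the assumption holds, weight_index_map(model1)['conv2d_6'] would be (10, 12) for the summary in the question, i.e. weights_list[10] is the kernel and weights_list[11] is the bias, because five weighted layers with two arrays each come before it.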

Edit: To extend this to all the desired layers, the simplest approach is to create a list of the layer names in model1 that you wish to transfer from (model1_layers) and a list of the layer names in model2 that you want to transfer to (model2_layers), in matching order. Then transfer the weights in a for loop as below.

for src_name, dst_name in zip(model1_layers, model2_layers):
    layer_weights = model1.get_layer(src_name).get_weights()
    model2.get_layer(dst_name).set_weights(layer_weights)
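
For the two summaries in the question, the loop above could be fleshed out roughly as follows. This is only a sketch: transfer_weights and name_pairs are hypothetical names, and the name offsets (conv2d_N in model1 to conv2d_(N-3) in model2, conv2d_transpose_N to conv2d_transpose_(N+3)) are read off the printed summaries. It compares shapes before copying, because conv2d_6 in model1 has 160 parameters while conv2d_3 in model2 has 448, so that first pair cannot be copied and is reported as skipped instead.

```python
def transfer_weights(src_model, dst_model, name_pairs):
    """Copy weights for each (source name, target name) pair.

    Returns the pairs that were skipped because their weight shapes differ.
    """
    skipped = []
    for src_name, dst_name in name_pairs:
        src_weights = src_model.get_layer(src_name).get_weights()
        dst_layer = dst_model.get_layer(dst_name)
        src_shapes = [w.shape for w in src_weights]
        dst_shapes = [w.shape for w in dst_layer.get_weights()]
        if src_shapes != dst_shapes:
            # set_weights would fail here, so leave the target layer as initialized
            skipped.append((src_name, dst_name))
            continue
        dst_layer.set_weights(src_weights)
    return skipped


# Name pairs assumed from the two summaries in the question.
name_pairs = [(f"conv2d_{i}", f"conv2d_{i - 3}") for i in range(6, 25)]
name_pairs += [
    (f"conv2d_transpose_{i}", f"conv2d_transpose_{i + 3}") for i in range(1, 4)
]

# Usage with the real models (not run here):
# skipped = transfer_weights(model1, model2, name_pairs)
```

Returning the skipped pairs rather than raising is a design choice for the sketch. It makes it easy to see which layers still need training from scratch.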


 