Hi, I have a pretrained network, model1, as follows:
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 3000, 200, 1) 0
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 3000, 200, 1) 26 input_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 1500, 200, 1) 0 conv2d_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 1500, 200, 1) 26 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D) (None, 750, 200, 1) 0 conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 750, 200, 1) 26 max_pooling2d_2[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D) (None, 375, 200, 1) 0 conv2d_3[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 375, 200, 1) 26 max_pooling2d_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D) (None, 187, 199, 1) 0 conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 128, 128, 1) 4321 max_pooling2d_4[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, 128, 128, 16) 160 conv2d_5[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, 128, 128, 16) 2320 conv2d_6[0][0]
__________________________________________________________________________________________________
max_pooling2d_5 (MaxPooling2D) (None, 64, 64, 16) 0 conv2d_7[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, 64, 64, 32) 4640 max_pooling2d_5[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, 64, 64, 32) 9248 conv2d_8[0][0]
__________________________________________________________________________________________________
max_pooling2d_6 (MaxPooling2D) (None, 32, 32, 32) 0 conv2d_9[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, 32, 32, 64) 18496 max_pooling2d_6[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, 32, 32, 64) 36928 conv2d_10[0][0]
__________________________________________________________________________________________________
max_pooling2d_7 (MaxPooling2D) (None, 16, 16, 64) 0 conv2d_11[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, 16, 16, 128) 73856 max_pooling2d_7[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, 16, 16, 128) 147584 conv2d_12[0][0]
__________________________________________________________________________________________________
max_pooling2d_8 (MaxPooling2D) (None, 8, 8, 128) 0 conv2d_13[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D) (None, 8, 8, 256) 295168 max_pooling2d_8[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D) (None, 8, 8, 256) 590080 conv2d_14[0][0]
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, 16, 16, 256) 0 conv2d_15[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 16, 16, 384) 0 up_sampling2d_1[0][0]
conv2d_13[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D) (None, 16, 16, 128) 442496 concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D) (None, 16, 16, 128) 147584 conv2d_16[0][0]
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D) (None, 32, 32, 128) 0 conv2d_17[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 32, 32, 192) 0 up_sampling2d_2[0][0]
conv2d_11[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D) (None, 32, 32, 64) 110656 concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_19 (Conv2D) (None, 32, 32, 64) 36928 conv2d_18[0][0]
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D) (None, 64, 64, 64) 0 conv2d_19[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate) (None, 64, 64, 96) 0 up_sampling2d_3[0][0]
conv2d_9[0][0]
__________________________________________________________________________________________________
conv2d_20 (Conv2D) (None, 64, 64, 32) 27680 concatenate_3[0][0]
__________________________________________________________________________________________________
conv2d_21 (Conv2D) (None, 64, 64, 32) 9248 conv2d_20[0][0]
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D) (None, 128, 128, 32) 0 conv2d_21[0][0]
__________________________________________________________________________________________________
concatenate_4 (Concatenate) (None, 128, 128, 48) 0 up_sampling2d_4[0][0]
conv2d_7[0][0]
__________________________________________________________________________________________________
conv2d_22 (Conv2D) (None, 128, 128, 16) 6928 concatenate_4[0][0]
__________________________________________________________________________________________________
conv2d_23 (Conv2D) (None, 128, 128, 16) 2320 conv2d_22[0][0]
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 261, 261, 8) 6280 conv2d_23[0][0]
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 267, 267, 4) 1572 conv2d_transpose_1[0][0]
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 300, 300, 2) 9250 conv2d_transpose_2[0][0]
__________________________________________________________________________________________________
conv2d_24 (Conv2D) (None, 300, 300, 1) 3 conv2d_transpose_3[0][0]
==================================================================================================
Total params: 1,983,850
Trainable params: 1,983,850
Non-trainable params: 0
I want to take the weights of model1 from conv2d_6 to the end and load them into model2 (shown below), from conv2d_3 to the end.
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 502, 200, 3) 0
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 502, 200, 3) 30 input_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 251, 200, 3) 0 conv2d_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 251, 200, 3) 48 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D) (None, 125, 100, 3) 0 conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 126, 109, 3) 183 max_pooling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 127, 118, 3) 183 conv2d_transpose_1[0][0]
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 128, 128, 3) 201 conv2d_transpose_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 128, 128, 16) 448 conv2d_transpose_3[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 128, 128, 16) 2320 conv2d_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D) (None, 64, 64, 16) 0 conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 64, 64, 32) 4640 max_pooling2d_3[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, 64, 64, 32) 9248 conv2d_5[0][0]
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D) (None, 32, 32, 32) 0 conv2d_6[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, 32, 32, 64) 18496 max_pooling2d_4[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, 32, 32, 64) 36928 conv2d_7[0][0]
__________________________________________________________________________________________________
max_pooling2d_5 (MaxPooling2D) (None, 16, 16, 64) 0 conv2d_8[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, 16, 16, 128) 73856 max_pooling2d_5[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, 16, 16, 128) 147584 conv2d_9[0][0]
__________________________________________________________________________________________________
max_pooling2d_6 (MaxPooling2D) (None, 8, 8, 128) 0 conv2d_10[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, 8, 8, 256) 295168 max_pooling2d_6[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, 8, 8, 256) 590080 conv2d_11[0][0]
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, 16, 16, 256) 0 conv2d_12[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 16, 16, 384) 0 up_sampling2d_1[0][0]
conv2d_10[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, 16, 16, 128) 442496 concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D) (None, 16, 16, 128) 147584 conv2d_13[0][0]
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D) (None, 32, 32, 128) 0 conv2d_14[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 32, 32, 192) 0 up_sampling2d_2[0][0]
conv2d_8[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D) (None, 32, 32, 64) 110656 concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D) (None, 32, 32, 64) 36928 conv2d_15[0][0]
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D) (None, 64, 64, 64) 0 conv2d_16[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate) (None, 64, 64, 96) 0 up_sampling2d_3[0][0]
conv2d_6[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D) (None, 64, 64, 32) 27680 concatenate_3[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D) (None, 64, 64, 32) 9248 conv2d_17[0][0]
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D) (None, 128, 128, 32) 0 conv2d_18[0][0]
__________________________________________________________________________________________________
concatenate_4 (Concatenate) (None, 128, 128, 48) 0 up_sampling2d_4[0][0]
conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_19 (Conv2D) (None, 128, 128, 16) 6928 concatenate_4[0][0]
__________________________________________________________________________________________________
conv2d_20 (Conv2D) (None, 128, 128, 16) 2320 conv2d_19[0][0]
__________________________________________________________________________________________________
conv2d_transpose_4 (Conv2DTrans (None, 261, 261, 8) 6280 conv2d_20[0][0]
__________________________________________________________________________________________________
conv2d_transpose_5 (Conv2DTrans (None, 267, 267, 4) 1572 conv2d_transpose_4[0][0]
__________________________________________________________________________________________________
conv2d_transpose_6 (Conv2DTrans (None, 300, 300, 2) 9250 conv2d_transpose_5[0][0]
__________________________________________________________________________________________________
conv2d_21 (Conv2D) (None, 300, 300, 1) 3 conv2d_transpose_6[0][0]
==================================================================================================
Total params: 1,980,358
Trainable params: 1,980,358
Non-trainable params: 0
I retrieved the weights from the pretrained model with weights_list = model1.get_weights(), and the resulting list has 54 entries. I can't work out how these entries map to layers: model1 has 43 layers, so the indexing confuses me. Is there a general way to understand the indexing that also applies to other models? In future I will need to copy the weights of specific layers between models.
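You need to extract the weights from the specific layer; model.get_weights() on its own returns the weights of the entire model as one flat list. Note that set_weights requires the arrays to match the shapes of the target layer's weights. Per the summaries above, conv2d_6 in model1 has 160 parameters while conv2d_3 in model2 has 448 (its input has 3 channels rather than 1), so that particular pair will likely fail with a shape mismatch, whereas the later pairs have matching parameter counts.
The following shows the pattern for a single pair of layers: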
model1_weights = model1.get_layer('conv2d_6').get_weights()
model2.get_layer('conv2d_3').set_weights(model1_weights)
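On the indexing question itself: the flat list from model.get_weights() appears to be built layer by layer, and layers without parameters (InputLayer, MaxPooling2D, UpSampling2D, Concatenate) contribute nothing. In model1's summary, 27 layers have parameters (24 Conv2D and 3 Conv2DTranspose); if each holds a kernel and a bias, that accounts for the 54 entries. The sketch below prints that mapping for any model. The helper names weight_index_map and check_against_flat_list are hypothetical, and the layer-by-layer ordering is an assumption, which is why the second helper compares the result against the real flat list.

```python
def weight_index_map(model):
    """Map positions in model.get_weights() to the layers that own them.

    Assumption: the flat list is ordered layer by layer, following
    model.layers. Use check_against_flat_list() to confirm this.
    """
    mapping = []
    index = 0
    for layer in model.layers:
        shapes = [w.shape for w in layer.get_weights()]
        if not shapes:
            # Layers without parameters (pooling, upsampling, concatenate,
            # input) take up no positions in the flat list.
            continue
        mapping.append((index, layer.name, shapes))
        index += len(shapes)
    return mapping


def check_against_flat_list(model, mapping):
    """Return True if the per-layer shapes line up with model.get_weights()."""
    flat_shapes = [w.shape for w in model.get_weights()]
    expected = [shape for _, _, shapes in mapping for shape in shapes]
    return flat_shapes == expected
```

Calling weight_index_map(model1) and printing each tuple should show the starting index, layer name, and array shapes for every layer that owns weights. If check_against_flat_list(model1, mapping) returns False, the ordering assumption does not hold for that Keras version and the indices should not be trusted.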
Edit: To extend this to all the desired layers, the simplest approach is to create a list of the layer names in model1 that you wish to transfer from (model1_layers) and a list of the layer names in model2 that you want to transfer to (model2_layers), in matching order. Then copy the weights in a for loop as below.
for i in range(len(model1_layers)):
    layer_weights = model1.get_layer(model1_layers[i]).get_weights()
    model2.get_layer(model2_layers[i]).set_weights(layer_weights)
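Rather than typing the two name lists by hand, they could be derived from the models. The following is a minimal sketch, not a definitive implementation: layers_with_weights_from and transfer_weights are hypothetical helper names, and it assumes both models list their layers in the same relative order from the chosen starting layers onward. It compares shapes before copying, so a mismatched pair is reported instead of raising an error. Going by the parameter counts in the question, conv2d_6 (160) and conv2d_3 (448) would be such a pair.

```python
def layers_with_weights_from(model, start_name):
    """Return the layers from start_name onward that own at least one weight."""
    names = [layer.name for layer in model.layers]
    start = names.index(start_name)  # raises ValueError if the name is absent
    return [layer for layer in model.layers[start:] if layer.get_weights()]


def transfer_weights(src_model, dst_model, src_start, dst_start):
    """Copy weights pairwise; return the (src, dst) name pairs that were skipped."""
    src_layers = layers_with_weights_from(src_model, src_start)
    dst_layers = layers_with_weights_from(dst_model, dst_start)
    if len(src_layers) != len(dst_layers):
        raise ValueError(
            "Layer counts differ: %d source vs %d destination"
            % (len(src_layers), len(dst_layers))
        )

    skipped = []
    for src, dst in zip(src_layers, dst_layers):
        src_weights = src.get_weights()
        src_shapes = [w.shape for w in src_weights]
        dst_shapes = [w.shape for w in dst.get_weights()]
        if src_shapes != dst_shapes:
            # set_weights would reject these arrays, so leave dst untouched.
            skipped.append((src.name, dst.name))
            continue
        dst.set_weights(src_weights)
    return skipped
```

For the models in the question this would be called as transfer_weights(model1, model2, 'conv2d_6', 'conv2d_3'). Any pairs in the returned list keep their original model2 weights and need separate handling, for example training them from scratch.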