
Transfer learning by loading weights partially

Hi, I have one pretrained network, model1, as follows:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 3000, 200, 1) 0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 3000, 200, 1) 26          input_1[0][0]                    
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 1500, 200, 1) 0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 1500, 200, 1) 26          max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 750, 200, 1)  0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 750, 200, 1)  26          max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 375, 200, 1)  0           conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 375, 200, 1)  26          max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 187, 199, 1)  0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 128, 128, 1)  4321        max_pooling2d_4[0][0]            
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 128, 128, 16) 160         conv2d_5[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 128, 128, 16) 2320        conv2d_6[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_5 (MaxPooling2D)  (None, 64, 64, 16)   0           conv2d_7[0][0]                   
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 64, 64, 32)   4640        max_pooling2d_5[0][0]            
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 64, 64, 32)   9248        conv2d_8[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_6 (MaxPooling2D)  (None, 32, 32, 32)   0           conv2d_9[0][0]                   
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 32, 32, 64)   18496       max_pooling2d_6[0][0]            
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 32, 32, 64)   36928       conv2d_10[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_7 (MaxPooling2D)  (None, 16, 16, 64)   0           conv2d_11[0][0]                  
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 16, 16, 128)  73856       max_pooling2d_7[0][0]            
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 16, 16, 128)  147584      conv2d_12[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_8 (MaxPooling2D)  (None, 8, 8, 128)    0           conv2d_13[0][0]                  
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 8, 8, 256)    295168      max_pooling2d_8[0][0]            
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 8, 8, 256)    590080      conv2d_14[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 16, 16, 256)  0           conv2d_15[0][0]                  
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 16, 16, 384)  0           up_sampling2d_1[0][0]            
                                                                 conv2d_13[0][0]                  
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 16, 16, 128)  442496      concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 16, 16, 128)  147584      conv2d_16[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 32, 32, 128)  0           conv2d_17[0][0]                  
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 32, 32, 192)  0           up_sampling2d_2[0][0]            
                                                                 conv2d_11[0][0]                  
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 32, 32, 64)   110656      concatenate_2[0][0]              
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 32, 32, 64)   36928       conv2d_18[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D)  (None, 64, 64, 64)   0           conv2d_19[0][0]                  
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 64, 64, 96)   0           up_sampling2d_3[0][0]            
                                                                 conv2d_9[0][0]                   
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 64, 64, 32)   27680       concatenate_3[0][0]              
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 64, 64, 32)   9248        conv2d_20[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D)  (None, 128, 128, 32) 0           conv2d_21[0][0]                  
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 128, 128, 48) 0           up_sampling2d_4[0][0]            
                                                                 conv2d_7[0][0]                   
__________________________________________________________________________________________________
conv2d_22 (Conv2D)              (None, 128, 128, 16) 6928        concatenate_4[0][0]              
__________________________________________________________________________________________________
conv2d_23 (Conv2D)              (None, 128, 128, 16) 2320        conv2d_22[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 261, 261, 8)  6280        conv2d_23[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 267, 267, 4)  1572        conv2d_transpose_1[0][0]         
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 300, 300, 2)  9250        conv2d_transpose_2[0][0]         
__________________________________________________________________________________________________
conv2d_24 (Conv2D)              (None, 300, 300, 1)  3           conv2d_transpose_3[0][0]         
==================================================================================================
Total params: 1,983,850
Trainable params: 1,983,850
Non-trainable params: 0

I want to take the weights in model1 from conv2d_6 to the end and load them into model2 (shown below), from conv2d_3 to the end.

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 502, 200, 3)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 502, 200, 3)  30          input_1[0][0]                    
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 251, 200, 3)  0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 251, 200, 3)  48          max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 125, 100, 3)  0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 126, 109, 3)  183         max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 127, 118, 3)  183         conv2d_transpose_1[0][0]         
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 128, 128, 3)  201         conv2d_transpose_2[0][0]         
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 128, 128, 16) 448         conv2d_transpose_3[0][0]         
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 128, 128, 16) 2320        conv2d_3[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 64, 64, 16)   0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 64, 64, 32)   4640        max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 64, 64, 32)   9248        conv2d_5[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 32, 32, 32)   0           conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 32, 32, 64)   18496       max_pooling2d_4[0][0]            
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 32, 32, 64)   36928       conv2d_7[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_5 (MaxPooling2D)  (None, 16, 16, 64)   0           conv2d_8[0][0]                   
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 16, 16, 128)  73856       max_pooling2d_5[0][0]            
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 16, 16, 128)  147584      conv2d_9[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_6 (MaxPooling2D)  (None, 8, 8, 128)    0           conv2d_10[0][0]                  
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 8, 8, 256)    295168      max_pooling2d_6[0][0]            
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 8, 8, 256)    590080      conv2d_11[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 16, 16, 256)  0           conv2d_12[0][0]                  
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 16, 16, 384)  0           up_sampling2d_1[0][0]            
                                                                 conv2d_10[0][0]                  
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 16, 16, 128)  442496      concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 16, 16, 128)  147584      conv2d_13[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 32, 32, 128)  0           conv2d_14[0][0]                  
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 32, 32, 192)  0           up_sampling2d_2[0][0]            
                                                                 conv2d_8[0][0]                   
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 32, 32, 64)   110656      concatenate_2[0][0]              
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 32, 32, 64)   36928       conv2d_15[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D)  (None, 64, 64, 64)   0           conv2d_16[0][0]                  
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 64, 64, 96)   0           up_sampling2d_3[0][0]            
                                                                 conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 64, 64, 32)   27680       concatenate_3[0][0]              
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 64, 64, 32)   9248        conv2d_17[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D)  (None, 128, 128, 32) 0           conv2d_18[0][0]                  
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 128, 128, 48) 0           up_sampling2d_4[0][0]            
                                                                 conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 128, 128, 16) 6928        concatenate_4[0][0]              
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 128, 128, 16) 2320        conv2d_19[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_4 (Conv2DTrans (None, 261, 261, 8)  6280        conv2d_20[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_5 (Conv2DTrans (None, 267, 267, 4)  1572        conv2d_transpose_4[0][0]         
__________________________________________________________________________________________________
conv2d_transpose_6 (Conv2DTrans (None, 300, 300, 2)  9250        conv2d_transpose_5[0][0]         
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 300, 300, 1)  3           conv2d_transpose_6[0][0]         
==================================================================================================
Total params: 1,980,358
Trainable params: 1,980,358
Non-trainable params: 0

I have taken weights_list = model1.get_weights() from the pretrained model, and the length of this list is 54. I just can't understand how the weights are indexed across the layers. Model1 has 43 layers, and the indexing is confusing me. Is there a general way to understand the indexing for other models as well? In future models I will be selecting specific layer weights between models.

You need to extract the weights from the specific layers. Using model.get_weights() on its own returns the weights of the entire model.
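As for the indexing of that flat list: in a model like this one (no BatchNormalization or frozen layers), model.get_weights() is effectively each layer's own get_weights() concatenated in model.layers order. Conv2D and Conv2DTranspose layers each contribute a kernel and then a bias, while InputLayer, MaxPooling2D, UpSampling2D and Concatenate contribute nothing. In model1 that gives 24 Conv2D + 3 Conv2DTranspose = 27 layers, and 27 × 2 = 54 arrays. The sketch below prints which layer owns each index so you can check this on any model instead of counting by hand; weight_index_map is a hypothetical helper name, not a Keras API.

```python
def weight_index_map(model):
    """List, for each position i in model.get_weights(), the layer that owns it.

    Hypothetical helper (not part of Keras). It walks model.layers in order and
    asks every layer for its own weight arrays; layers without weights
    (InputLayer, MaxPooling2D, UpSampling2D, Concatenate) add no entries.
    """
    index_map = []
    for layer in model.layers:
        for position, array in enumerate(layer.get_weights()):
            # For Conv2D / Conv2DTranspose with a bias, position 0 is the
            # kernel and position 1 is the bias.
            index_map.append((layer.name, position, tuple(array.shape)))
    return index_map


# Usage with the question's model1 (assumed to be already built or loaded):
# for i, (name, position, shape) in enumerate(weight_index_map(model1)):
#     print(i, name, "kernel" if position == 0 else "bias", shape)
```

For model1 this should place conv2d_6's kernel and bias at indices 10 and 11 (conv2d_1 to conv2d_5 occupy indices 0 to 9), so model1.get_weights()[10:] corresponds to "conv2d_6 to the end".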

The following should work:

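# get_layer() looks a layer up by name; get_weights() returns only that layer's arrays (for Conv2D: [kernel, bias])
model1_weights = model1.get_layer('conv2d_6').get_weights()
# set_weights() expects arrays with the same shapes as the target layer's current weights
model2.get_layer('conv2d_3').set_weights(model1_weights)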
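One caveat for this exact pair: set_weights() only accepts arrays with the same shapes as the target layer's existing weights. model1's conv2d_6 reads the 1-channel output of conv2d_5 (160 parameters), while model2's conv2d_3 reads the 3-channel output of conv2d_transpose_3 (448 parameters), so the call above should fail with a ValueError. The next pair, conv2d_7 and conv2d_4 (2320 parameters each), lines up. A sketch of a guarded copy that compares shapes first; copy_layer_weights is a hypothetical helper name, not part of Keras.

```python
def copy_layer_weights(src_model, dst_model, src_name, dst_name):
    """Copy one layer's weights if the shapes line up; return True on success.

    Hypothetical helper (not a Keras API). set_weights() needs arrays with the
    same shapes as the target layer's current weights, so compare first.
    """
    src_weights = src_model.get_layer(src_name).get_weights()
    dst_layer = dst_model.get_layer(dst_name)
    src_shapes = [tuple(w.shape) for w in src_weights]
    dst_shapes = [tuple(w.shape) for w in dst_layer.get_weights()]
    if src_shapes != dst_shapes:
        # Leave the target layer untouched and report the mismatch.
        print(f"skipping {src_name} -> {dst_name}: {src_shapes} vs {dst_shapes}")
        return False
    dst_layer.set_weights(src_weights)
    return True


# Usage with the question's models (assumed to be already built):
# copy_layer_weights(model1, model2, "conv2d_6", "conv2d_3")  # shapes differ
# copy_layer_weights(model1, model2, "conv2d_7", "conv2d_4")  # shapes match
```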

Edit: To extend this to all the desired layers, the simplest approach is to create a list of the layer names in model1 that you want to transfer from (model1_layers) and a list of the layer names in model2 that you want to transfer to (model2_layers). Then transfer the weights in a for loop, as below:

# Pair the two name lists position by position
for model1_layer, model2_layer in zip(model1_layers, model2_layers):
    layer_weights = model1.get_layer(model1_layer).get_weights()
    model2.get_layer(model2_layer).set_weights(layer_weights)
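The two name lists do not have to be typed by hand. A sketch, assuming (as the two summaries in the question suggest) that both models have the same structure from the starting layer onward, so their weighted layers line up one-to-one in model.layers order. weighted_layer_names is a hypothetical helper, not a Keras function; it skips layers whose get_weights() list is empty.

```python
def weighted_layer_names(model, start_name):
    """Names of the layers that own weights, from start_name to the end.

    Hypothetical helper. Layers whose get_weights() is empty (input, pooling,
    upsampling, concatenation) are skipped, so lists built for two models
    with the same tail can be paired position by position.
    """
    names = [layer.name for layer in model.layers if layer.get_weights()]
    if start_name not in names:
        raise ValueError(f"{start_name!r} is not a layer with weights")
    return names[names.index(start_name):]


# Usage with the question's models (assumed to be already built):
# model1_layers = weighted_layer_names(model1, "conv2d_6")
# model2_layers = weighted_layer_names(model2, "conv2d_3")
```

With the summaries in the question, both lists should contain 22 names: model1's conv2d_6 to conv2d_23, conv2d_transpose_1 to conv2d_transpose_3 and conv2d_24, paired with model2's conv2d_3 to conv2d_20, conv2d_transpose_4 to conv2d_transpose_6 and conv2d_21. As with the single-layer call, a pair can only be copied if its weight shapes match; here the first pair (conv2d_6 with 160 parameters vs conv2d_3 with 448) does not, while every later pair has matching parameter counts.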
