
Keras: how to load weights partially?


How can I load model weights partially? For example, I want to load only block1 of VGG19 using the original ImageNet weights (vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5):

import tensorflow as tf

def VGG19_part(input_shape=None):
    img_input = tf.keras.layers.Input(shape=input_shape)

    # Block 1
    x = tf.keras.layers.Conv2D(64, (3, 3),
                               activation='linear',
                               padding='same',
                               name='block1_conv1')(img_input)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.Conv2D(64, (3, 3),
                               activation='linear',
                               padding='same',
                               name='block1_conv2')(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)

    model = tf.keras.Model(img_input, x, name='vgg19')

    model.load_weights('/Users/myuser/.keras/models/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5')

    model.summary()  # summary() prints the table itself and returns None

    return model

This code raises an error: ValueError: You are trying to load a weight file containing 16 layers into a model with 2 layers.
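The error arises because, for HDF5 files, load_weights matches layers by position by default: Keras counts only the layers that carry weights, and the notop file has 16 of them (VGG19's conv layers) while this model has 2. In tf.keras, passing by_name=True to load_weights (supported for HDF5 files only) matches layers by name instead, so block1_conv1 and block1_conv2 should receive their ImageNet weights and the remaining 14 layers in the file are skipped. A minimal sketch, assuming the weights file path from the question and a Keras version whose load_weights still accepts by_name for .h5 files (check the documentation for your version); build_vgg19_block1 and load_block1 are illustrative names, not Keras APIs:

```python
import tensorflow as tf

# Assumption: path copied from the question; adjust it to where Keras cached the file.
WEIGHTS_PATH = '/Users/myuser/.keras/models/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5'


def build_vgg19_block1(input_shape=(224, 224, 3)):
    """Rebuild VGG19's block1, reusing the original layer names."""
    img_input = tf.keras.layers.Input(shape=input_shape)
    x = tf.keras.layers.Conv2D(64, (3, 3), activation='linear', padding='same',
                               name='block1_conv1')(img_input)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.Conv2D(64, (3, 3), activation='linear', padding='same',
                               name='block1_conv2')(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)
    return tf.keras.Model(img_input, x, name='vgg19_block1')


def load_block1(weights_path=WEIGHTS_PATH, input_shape=(224, 224, 3)):
    """Build block1 and fill it from the full notop HDF5 weight file."""
    model = build_vgg19_block1(input_shape)
    # by_name=True matches the file's layer groups to model layers by name;
    # layers in the file with no same-named layer in the model are skipped.
    model.load_weights(weights_path, by_name=True)
    return model
```

Calling load_block1() takes the place of the failing load_weights call. Because matching is by name, the conv layer names must match the names stored in the file exactly.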

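The VGG19 model in the Keras applications module loads ImageNet weights by default (weights='imagenet'), so I use it as the source and copy the weights we are interested in into the custom model: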

import tensorflow as tf

input_shape = (224, 224, 3)

# Full VGG19 without the classifier head, with ImageNet weights
full_vgg19 = tf.keras.applications.VGG19(include_top=False, weights='imagenet', input_shape=input_shape)

def VGG19_part(full_vgg19, input_shape=None):
    img_input = tf.keras.layers.Input(shape=input_shape)

    # Block 1
    x = tf.keras.layers.Conv2D(64, (3, 3),
                               activation='linear',
                               padding='same',
                               name='block1_conv1')(img_input)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.Conv2D(64, (3, 3),
                               activation='linear',
                               padding='same',
                               name='block1_conv2')(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)

    model = tf.keras.Model(img_input, x, name='vgg19')
    # Copy the first four arrays: block1_conv1 kernel/bias, block1_conv2 kernel/bias
    model.set_weights(full_vgg19.get_weights()[:4])

    return model

part_vgg19 = VGG19_part(full_vgg19, input_shape)

# Check that the copied kernels and biases are identical:
[(i == j).all() for i, j in zip(part_vgg19.get_weights()[:4], full_vgg19.get_weights()[:4])]  # True True True True
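The [:4] slice works because get_weights() returns each layer's arrays in layer order (kernel, then bias), so the first four arrays of the full model belong to block1_conv1 and block1_conv2. That holds for VGG19, but a positional copy could pick up the wrong arrays if the layer order ever differed while the shapes still matched. A variant that copies layer by layer using matching names, sketched with the standard get_layer, get_weights and set_weights methods (copy_weights_by_name is a hypothetical helper, not a Keras API):

```python
import tensorflow as tf


def copy_weights_by_name(src_model, dst_model):
    """Copy weights into every weighted layer of dst_model that has a
    same-named layer in src_model; return the names of the copied layers."""
    src_names = {layer.name for layer in src_model.layers}
    copied = []
    for layer in dst_model.layers:
        # Skip layers without weights (Input, Activation, MaxPooling2D).
        if layer.weights and layer.name in src_names:
            # set_weights raises ValueError if the array shapes do not match.
            layer.set_weights(src_model.get_layer(layer.name).get_weights())
            copied.append(layer.name)
    return copied
```

Inside VGG19_part, copy_weights_by_name(full_vgg19, model) could replace the model.set_weights(full_vgg19.get_weights()[:4]) line; the returned list shows which layers were actually filled.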


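Statement: The technical posts on this site follow the CC BY-SA 4.0 license. If you repost, please cite this site's URL or the original address. For any questions, contact: yoyou2525@163.com.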

 