
How to set weights for specific channels in keras with pre-trained weights?

I am trying to fine-tune a ResNet50 architecture (I built mine based on the Keras implementation) using the pre-trained weights provided by Keras.

The drawback of this pre-trained model is that it was trained on images with three channels. In my case the inputs have more than three channels: it can be 5, 6, ...

This variation in the number of channels means that the shape of the first layer, conv1, depends on the number of input channels. So to use the pre-trained weights I have two possibilities.

  1. Load the pre-trained weights for every layer after conv1 and initialize conv1 itself randomly.

  2. Set the first three input channels of conv1 to the RGB weights and fill the remaining channels by replicating the RGB weights.
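Both possibilities follow from the fact that a convolution kernel has shape (height, width, input channels, filters), so in this network only conv1 changes size with the input. A minimal sketch of that arithmetic, assuming the 7x7, 64-filter conv1 shown in the code further down (the helper name `conv2d_param_count` is made up for illustration):

```python
def conv2d_param_count(kernel_size, in_channels, filters, use_bias=True):
    """Number of trainable parameters in a standard 2D convolution layer."""
    kh, kw = kernel_size
    weights = kh * kw * in_channels * filters
    return weights + (filters if use_bias else 0)


# conv1 of ResNet50: 7x7 kernel, 64 filters, one bias per filter.
for channels in (3, 5, 6):
    print(channels, conv2d_param_count((7, 7), channels, 64))
# 3 9472
# 5 15744
# 6 18880
```

The pre-trained file stores the 3-channel version, so its conv1 kernel cannot be loaded directly into a layer built for 5 or 6 channels.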

I tried the second possibility, but it only works when the number of channels is a multiple of 3. Moreover, if I want a specific initializer (glorot_uniform, for instance) for the extra channels instead of duplicating the RGB bands, it seems to be impossible.

So I would like to know whether there are functions, or approaches other than mine, that achieve this, and in particular that work with any number of channels rather than only multiples of 3.

Note: Before trying the second possibility I looked for existing functions that do this, but I did not find anything.

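import h5py
from keras import models
from keras.layers import (Activation, AveragePooling2D, BatchNormalization,
                          Conv2D, Dense, Flatten, Input, MaxPooling2D,
                          ZeroPadding2D)
from keras.utils import get_file

# KERNEL_INIT, CHANNEL_AXIS, WEIGHTS_PATH_NO_TOP, _convolution_block and
# _identity_block are defined elsewhere in my code (not shown here).

def ResNet50(load_weights=True,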
             input_shape=None,
             include_top=False,
             classes=100):
    img_input = Input(shape=input_shape, name='tuned_input')
    x = ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)

    # Stage 1 (conv1_x)
    x = Conv2D(64, (7, 7),
               strides=(2, 2),
               padding='valid',
               kernel_initializer=KERNEL_INIT,
               name='tuned_conv1')(x)

    x = BatchNormalization(axis=CHANNEL_AXIS, name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = ZeroPadding2D(padding=(1, 1), name='pool1_pad')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    # Stage 2 (conv2_x)
    x = _convolution_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    for block in ['b', 'c']:
        x = _identity_block(x, 3, [64, 64, 256], stage=2, block=block)

    # Stage 3 (conv3_x)
    x = _convolution_block(x, 3, [128, 128, 512], stage=3, block='a')
    for block in ['b', 'c', 'd']:
        x = _identity_block(x, 3, [128, 128, 512], stage=3, block=block)

    # Stage 4 (conv4_x)
    x = _convolution_block(x, 3, [256, 256, 1024], stage=4, block='a')
    for block in ['b', 'c', 'd', 'e', 'f']:
        x = _identity_block(x, 3, [256, 256, 1024], stage=4, block=block)

    # Stage 5 (conv5_x)
    x = _convolution_block(x, 3, [512, 512, 2048], stage=5, block='a')
    for block in ['b', 'c']:
        x = _identity_block(x, 3, [512, 512, 2048], stage=5, block=block)

    # AVGPOOL
    x = AveragePooling2D((2, 2), name="avg_pool")(x)
    if include_top:
        # output layer
        x = Flatten()(x)
        x = Dense(classes, activation='softmax', name='fc' + str(classes), kernel_initializer=KERNEL_INIT)(x)

    inputs = img_input
    # Create model.
    model = models.Model(inputs, x, name='resnet50')

    if load_weights:
        weights_path = get_file(
            'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5',
            WEIGHTS_PATH_NO_TOP,
            cache_subdir='models',
            md5_hash='a268eb855778b3df3c7506639542a6af')
        model.load_weights(weights_path, by_name=True)
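        # Set weights for conv1 for 6 channels: repeat each RGB slice twice
        # along the input-channel axis (axis=2), giving R, R, G, G, B, B.
        with h5py.File(weights_path, 'r') as f:
            d = f['conv1']
            kernel = d['conv1_W_1:0'][:].repeat(2, axis=2)
            bias = d['conv1_b_1:0'][:]
        model.get_layer('tuned_conv1').set_weights([kernel, bias])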

    return model

# example image 50x50 with 6 channels
ResNet50(input_shape=(50,50,6))
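The `repeat(2, axis=2)` call above only covers the 6-channel case. One way the conv1 kernel could be built for any channel count is sketched below in plain NumPy. This is an illustration, not a Keras utility: the function name `expand_conv1_kernel`, its `mode` argument, and the 3/num_channels rescaling are all assumptions made for the example. The "cycle" mode reuses the RGB slices in order (R, G, B, R, G, ...), and the "glorot" mode keeps the RGB slices and draws the extra channels from a Glorot-uniform range.

```python
import numpy as np


def expand_conv1_kernel(rgb_kernel, num_channels, mode="cycle", seed=0):
    """Build a (kh, kw, num_channels, filters) kernel from a (kh, kw, 3, filters) one.

    Hypothetical helper for illustration only.
    """
    kh, kw, in_ch, filters = rgb_kernel.shape
    if in_ch != 3:
        raise ValueError("expected a kernel with 3 input channels")
    if num_channels < 3:
        raise ValueError("num_channels must be at least 3")

    if mode == "cycle":
        # Channel k reuses RGB slice k % 3, so any channel count works.
        idx = np.arange(num_channels) % 3
        kernel = rgb_kernel[:, :, idx, :]
        # Assumption: scale by 3 / num_channels so the summed response stays
        # roughly comparable to the original 3-channel layer.
        return kernel * (3.0 / num_channels)

    if mode == "glorot":
        # Keep the pre-trained RGB slices and draw the extra channels from
        # U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)).
        fan_in = kh * kw * num_channels
        fan_out = kh * kw * filters
        limit = np.sqrt(6.0 / (fan_in + fan_out))
        rng = np.random.default_rng(seed)
        extra = rng.uniform(-limit, limit,
                            size=(kh, kw, num_channels - 3, filters))
        kernel = np.concatenate([rgb_kernel, extra], axis=2)
        return kernel.astype(rgb_kernel.dtype)

    raise ValueError("unknown mode: %r" % (mode,))
```

With the kernel and bias read from the weights file as in the code above, the result could be passed to `model.get_layer('tuned_conv1').set_weights([expand_conv1_kernel(kernel, 5), bias])`. Whether the rescaling or the mixed initialization actually helps fine-tuning would need to be checked experimentally.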

Your ResNet model can process input images with 3 channels (e.g. RGB images), but your images may have any number of channels. One way to overcome this is to replicate each channel of the input image 3 times, run each replicated channel through the model independently, and then concatenate the results (which are the feature maps of the model's final layer). Here is a sketch of this approach:

from keras import backend as K
from keras.layers import Input, Lambda, concatenate

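# w, h, num_channels and resnet_model (the 3-channel model) are assumed
# to be defined already.
inp = Input(shape=(w, h, num_channels))

out_maps = []
for i in range(num_channels):
    # Take channel i (keeping its axis) and repeat it 3 times -> (w, h, 3).
    # i is bound as a default argument so each Lambda keeps its own channel.
    rep_c = Lambda(
        lambda x, i=i: K.repeat_elements(x[:, :, :, i:i + 1], 3, axis=-1))(inp)
    out_maps.append(resnet_model(rep_c))

concat = concatenate(out_maps)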

# the rest of the model goes here...

Note that, depending on your data and the problem you are working on, this approach may or may not give good accuracy, i.e. if you are not sure, you need to experiment to find out.
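To make the shapes in this approach concrete, here is a small NumPy-only sketch of the replication step, independent of Keras. The helper name `replicate_channel` and the 6-channel 50x50 batch are made up for the example.

```python
import numpy as np


def replicate_channel(batch, i):
    """Return channel i of an (N, H, W, C) batch as an (N, H, W, 3) array."""
    single = batch[..., i:i + 1]          # keep the channel axis: (N, H, W, 1)
    return np.repeat(single, 3, axis=-1)  # copy it three times: (N, H, W, 3)


batch = np.random.rand(2, 50, 50, 6)  # hypothetical batch of 6-channel images
pseudo_rgb = [replicate_channel(batch, i) for i in range(batch.shape[-1])]
print(len(pseudo_rgb), pseudo_rgb[0].shape)
```

Each of the resulting arrays has the 3-channel shape the pre-trained model expects, so the model runs once per original channel and the cost grows linearly with the number of channels.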

You can simply copy the ResNet50 model definition into your own code and change the number of input channels to whatever you need.

I believe this should work at least when you are not using the pre-trained weights, which is my case, since I use the model for a different application.

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address.
