Input layer for VGG16 in Keras

I am building a U-Net and I'd like to use pre-trained model (VGG16) for the decoder part.

The challenge is that I have grayscale images, while VGG works with RGB.

I have found a function to convert it to RGB (by concatenating):

from keras.layers import Layer
from keras import backend as K

class Gray2VGGInput(Layer):
    """Custom conversion layer"""
    def build(self, x):
        self.image_mean = K.variable(value=np.array([103.939, 116.779, 123.68]).reshape([1,1,1,3]).astype('float32'), 
                                     name='imageNet_mean' )
        self.built = True
    def call(self, x):
        rgb_x = K.concatenate([x,x,x], axis=-1 )
        norm_x = rgb_x - self.image_mean
        return norm_x

    def compute_output_shape(self, input_shape):
        return input_shape[:3] + (3,)

But I fail to plug it into the model. The Gray2VGGInput is a layer, so I am looking for a way how to connect this layer to those from VGG . Below is my attempt:

def UNET1_VGG16():
    UNET with pretrained layers from VGG16
    def upsampleLayer(in_layer, concat_layer, input_size):
        Upsampling (=Decoder) layer building block
        in_layer: input layer
        concat_layer: layer with which to concatenate
        input_size: input size fot convolution
        upsample = Conv2DTranspose(input_size, (2, 2), strides=(2, 2), padding='same')(in_layer)    
        upsample = concatenate([upsample, concat_layer])
        conv = Conv2D(input_size, (1, 1), activation='relu', kernel_initializer='he_normal', padding='same')(upsample)
        conv = BatchNormalization()(conv)
        conv = Dropout(0.2)(conv)
        conv = Conv2D(input_size, (1, 1), activation='relu', kernel_initializer='he_normal', padding='same')(conv)
        conv = BatchNormalization()(conv)
        return conv

    img_rows = 864
    img_cols = 1232

    #batch, height, width, channels
    inputs_1 = Input((img_rows, img_cols, 1))
    inputs_3 = Input((img_rows, img_cols, 3))

    #VGG16 BASE
    #Prepare net
    base_VGG16 = VGG16(input_tensor=inputs_3, 

    #This is the problematic part

    vgg_inputs_3 = Gray2VGGInput(name='gray_to_rgb')(inputs_1)

    model_input = Model(inputs=[inputs_1], outputs=[vgg_inputs_3])

    new_outputs = base_VGG16(model_input.output)
    new_inputs = Model(inputs_1, new_outputs)

    c1 = base_VGG16.get_layer("block1_conv2").output #(None, 864, 1232, 64)
    c2 = base_VGG16.get_layer("block2_conv2").output #(None, 432, 616, 128) 
    c3 = base_VGG16.get_layer("block3_conv2").output #(None, 216, 308, 256) 
    c4 = base_VGG16.get_layer("block4_conv2").output #(None, 108, 154, 512) 

    c5 = base_VGG16.get_layer("block5_conv2").output #(None, 54, 77, 512)

    c6 = upsampleLayer(in_layer=c5, concat_layer=c4, input_size=512)
    c7 = upsampleLayer(in_layer=c6, concat_layer=c3, input_size=256)
    c8 = upsampleLayer(in_layer=c7, concat_layer=c2, input_size=128)
    c9 = upsampleLayer(in_layer=c8, concat_layer=c1, input_size=64)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(c9)

    model = Model(inputs=[new_inputs.input], outputs=[outputs])

    #Freeze layers
    for layer in model.layers[:16]:
        layer.trainable = False



    return model 

I get following error:

ValueError: Graph disconnected: cannot obtain value for tensor Tensor("input_14:0", shape=(?, 864, 1232, 3), dtype=float32) at layer "input_14". The following previous layers were accessed without issue: []

I think you do not need multiple inputs, rather pass your Gray2VGGInput layer output as the input to the VGG16 model. I think how you get output tensors from the VGG16 model is alright. Here is something I can suggest:

from keras.applications import VGG16

inputs_1 = Input(shape=(img_rows, img_cols, 1))
inputs_3 = Gray2VGGInput(name='gray_to_rgb')(inputs_1)  #shape=(img_rows, img_cols, 3)
base_VGG16 = VGG16(include_top=False, weights='imagenet', input_tensor=inputs_3)

c1 = base_VGG16.get_layer("block1_conv2").output #(None, 864, 1232, 64)
c2 = base_VGG16.get_layer("block2_conv2").output #(None, 432, 616, 128) 
c3 = base_VGG16.get_layer("block3_conv2").output #(None, 216, 308, 256) 
c4 = base_VGG16.get_layer("block4_conv2").output #(None, 108, 154, 512) 

c5 = base_VGG16.get_layer("block5_conv2").output #(None, 54, 77, 512)
... and so on

The model can be called as

model = Model(inputs=inputs_1, outputs=outputs)

You can give this a try and let me know if it works. I haven't tested it so there might be mistakes.


model_input = Model(inputs=[inputs_1], outputs=[vgg_inputs_3]) 


model_input = Model(inputs=[vgg_inputs_3] etc...

