I need to initialize custom Conv2D kernels with weights of the form
W = a1*b1 + a2*b2 + ... + an*bn
where
W = the custom weight to initialize the Conv2D layer with
a = random weight tensors created as keras.backend.variable(np.random.uniform()), shape=(64, 1, 10)
b = fixed basis filters defined as keras.backend.constant(...), shape=(10, 11, 11)
W = K.sum(a[:, :, :, None, None] * b[None, None, :, :, :], axis=2) #shape=(64, 1, 11, 11)
I want my model to update W by changing only the a values while keeping the b values constant.
I pass the custom W as
Conv2D(64, kernel_size=(11, 11), activation='relu', kernel_initializer=kernel_init_L1)(img)
where kernel_init_L1 returns keras.backend.variable(K.reshape(w_L1, (11, 11, 1, 64)))
Problem:
I am not sure if this is the correct way to do this. Is it possible to specify in Keras which weights are trainable and which are not? I know that layers can be set to trainable = True, but I am not sure about individual weights.
I think the implementation is incorrect because I get similar results from my model with or without the custom initializations.
It would be immensely helpful if someone can point out any mistakes in my approach or provide a way to verify it.
Warning about your shapes: if your kernel size is (11,11), and assuming you have 64 input channels and 1 output channel, your final kernel shape must be (11,11,64,1).
You should probably be going for a[None,None] and b[:,:,:,None,None].
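The shape bookkeeping can be checked outside Keras. The following is a minimal NumPy sketch, not part of the original answer. It assumes, as the warning above does, that the 64 axis of a is the input-channel axis and the 1 axis is the output-channel axis; if 64 is instead the number of filters, as in Conv2D(64, ...), the einsum subscripts would become "oin,nhw->hwio". It also illustrates that moving axes requires a transpose, since a reshape only reinterprets the flat buffer.

```python
import numpy as np

rng = np.random.default_rng(0)
# Assumption: a is read as (in_channels, out_channels, n_basis).
a = rng.uniform(size=(64, 1, 10))
# Fixed basis filters: (n_basis, height, width).
b = rng.uniform(size=(10, 11, 11))

# Sum over the basis axis n and emit the axes directly in Keras order
# (kernel_height, kernel_width, input_channels, output_channels).
W = np.einsum("ion,nhw->hwio", a, b)
print(W.shape)

# The question's layout (64, 1, 11, 11), computed as in its K.sum line.
W_q = np.sum(a[:, :, :, None, None] * b[None, None, :, :, :], axis=2)

# A transpose reorders the axes; a reshape keeps the flat element order.
W_transposed = np.transpose(W_q, (2, 3, 0, 1))
W_reshaped = np.reshape(W_q, (11, 11, 64, 1))
print(np.allclose(W_transposed, W))
print(np.allclose(W_reshaped, W))
```

With random inputs, the transposed array matches the einsum result while the reshaped one does not, so a K.reshape from (64, 1, 11, 11) to a Keras kernel layout would scramble the filters.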
    from keras import backend as K
    from keras.layers import Conv2D, InputSpec

    class CustomConv2D(Conv2D):

        def __init__(self, filters, kernel_size, kernelB = None, **kwargs):
            super(CustomConv2D, self).__init__(filters, kernel_size, **kwargs)
            self.kernelB = kernelB

        def build(self, input_shape):
            #locate the channel axis, as the original "build" in Conv2D does
            if self.data_format == 'channels_first':
                channel_axis = 1
            else:
                channel_axis = -1
            input_dim = input_shape[channel_axis]

            #use the input_shape to calculate the shapes of A and B
            #if needed, pay attention to the "data_format" used.

            #this is an actual weight, because it uses `self.add_weight`
            #(shape_of_kernel_A is a placeholder tuple you must define)
            self.kernelA = self.add_weight(
                shape=shape_of_kernel_A + (1,1), #or (1,1) + shape_of_A
                initializer='glorot_uniform', #or select another
                name='kernelA',
                regularizer=self.kernel_regularizer,
                constraint=self.kernel_constraint)

            #this is an ordinary var that will participate in the calculation
            #not a weight, not updated
            if self.kernelB is None:
                self.kernelB = K.constant(....)
                #use the shape already containing the new axes

            #in the original conv layer, this property would be the actual kernel,
            #now it's just a var that will be used in the original's "call" method
            self.kernel = K.sum(self.kernelA * self.kernelB, axis=2)
            #important: the resulting shape should be:
            #(kernelSizeX, kernelSizeY, input_channels, output_channels)

            #the following are remains of the original code for "build" in Conv2D
            #use_bias is True by default
            if self.use_bias:
                self.bias = self.add_weight(shape=(self.filters,),
                                            initializer=self.bias_initializer,
                                            name='bias',
                                            regularizer=self.bias_regularizer,
                                            constraint=self.bias_constraint)
            else:
                self.bias = None

            # Set input spec.
            self.input_spec = InputSpec(ndim=self.rank + 2,
                                        axes={channel_axis: input_dim})
            self.built = True
When you create a custom layer from scratch (derived from Layer), you should have these methods:
- __init__(self, ... parameters ...) - this is the constructor, called when you create a new instance of your layer. Here, you store the values the user passed as parameters. (In a Conv2D, the init would have the "filters", "kernel_size", etc.)
- build(self, input_shape) - this is where you should create the weights (all learnable vars are created here, based on the input shape)
- compute_output_shape(self, input_shape) - here you return the output shape based on the input shape
- call(self, inputs) - here you perform the actual layer calculations
Since we're not creating this layer from scratch, but deriving it from Conv2D, everything is ready; all we did was "change" the build method and replace what would be considered the kernel of the Conv2D layer.
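For comparison, here is a minimal sketch of a layer written from scratch with those four methods. It is not the answer's code: it assumes TensorFlow 2 with tf.keras and channels_last inputs, and the class name BasisConv2D, its basis argument, and the coefficient layout (in_channels, filters, n_basis) are hypothetical. One deliberate difference is that the kernel is recomputed inside call rather than once in build, which is an assumption made here so that gradients reach the coefficients under eager execution.

```python
import numpy as np
import tensorflow as tf


class BasisConv2D(tf.keras.layers.Layer):
    """Hypothetical layer: kernel = sum_n a[..., n] * basis[n]; only `a` is trained."""

    def __init__(self, filters, basis, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        # Fixed basis filters, shape (n_basis, kh, kw). A constant is not a weight.
        self.basis = tf.constant(basis, dtype=tf.float32)

    def build(self, input_shape):
        in_channels = int(input_shape[-1])  # assumes channels_last
        n_basis = int(self.basis.shape[0])
        # The only trainable weight: one coefficient per (input, filter, basis).
        self.a = self.add_weight(
            name="a",
            shape=(in_channels, self.filters, n_basis),
            initializer="glorot_uniform",
            trainable=True,
        )

    def call(self, inputs):
        # Rebuild the kernel in Keras layout (kh, kw, in_channels, filters).
        kernel = tf.einsum("ion,nhw->hwio", self.a, self.basis)
        return tf.nn.conv2d(inputs, kernel, strides=1, padding="VALID")

    def compute_output_shape(self, input_shape):
        kh, kw = int(self.basis.shape[1]), int(self.basis.shape[2])
        return (input_shape[0], input_shape[1] - kh + 1,
                input_shape[2] - kw + 1, self.filters)


basis = np.random.uniform(size=(10, 11, 11)).astype("float32")
layer = BasisConv2D(64, basis)
x = tf.ones((2, 32, 32, 1))

# Check that a loss on the output produces a gradient for `a`.
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(layer(x))
grads = tape.gradient(loss, layer.trainable_weights)
print([tuple(w.shape) for w in layer.trainable_weights])
```

Inspecting layer.trainable_weights is also one way to verify which variables an optimizer would update: here it should list only the coefficient tensor, not the basis.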
More on custom layers: https://keras.io/layers/writing-your-own-keras-layers/
The call method for conv layers is defined in class _Conv(Layer) in the Keras source.