
2x nested TensorFlow custom layers result in zero trainable parameters

I am creating a series of custom TensorFlow (version 2.4.1) layers and am running into a problem where the model summary shows zero trainable parameters. Below is a series of examples showing that everything is fine until I add the last custom layer.

Here are the imports and custom classes:

from tensorflow.keras.models import Model
from tensorflow.keras.layers import (BatchNormalization, Conv2D, Input, ReLU, 
                                     Layer)

class basic_conv_stack(Layer):
    def __init__(self, filters, kernel_size, strides):
        super(basic_conv_stack, self).__init__()
        self.conv1 = Conv2D(filters, kernel_size, strides, padding='same')
        self.bn1 = BatchNormalization()
        self.relu = ReLU()

    def call(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        return x
    
class basic_residual(Layer):
    def __init__(self, filters, kernel_size, strides):
        super(basic_residual, self).__init__()
        self.bcs1 = basic_conv_stack(filters, kernel_size, strides)
        self.bcs2 = basic_conv_stack(filters, kernel_size, strides)

    def call(self, x):
        x = self.bcs1(x)
        x = self.bcs2(x)
        return x
    
class basic_module(Layer):
    def __init__(self, filters, kernel_size, strides):
        super(basic_module, self).__init__()
        self.res = basic_residual
        self.args = (filters, kernel_size, strides)
    
    def call(self, x):
        for _ in range(4):
            x = self.res(*self.args)(x)
        return x

Now, if I do the following, everything works out ok and I get 300 trainable parameters:

input_layer = Input((128, 128, 3))
conv = basic_conv_stack(10, 3, 1)(input_layer)

model = Model(input_layer, conv)
model.summary()

Similarly, if I do the following, I get 1,230 trainable parameters:

input_layer = Input((128, 128, 3))
conv = basic_residual(10, 3, 1)(input_layer)

model = Model(input_layer, conv)
model.summary()
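Both counts can be checked by hand. The sketch below assumes Keras defaults: Conv2D has a kernel plus a bias, BatchNormalization has trainable gamma and beta (its moving mean and variance are non-trainable), and ReLU has no weights. The helper names are hypothetical and exist only for this arithmetic.

```python
def conv_stack_params(in_channels, filters, kernel_size):
    """Trainable parameters of one basic_conv_stack (Conv2D + BatchNormalization + ReLU)."""
    conv = kernel_size * kernel_size * in_channels * filters + filters  # kernel + bias
    bn = 2 * filters  # gamma and beta only; moving statistics are non-trainable
    return conv + bn  # ReLU contributes nothing


def residual_params(in_channels, filters, kernel_size):
    """basic_residual: two conv stacks; the second one sees `filters` input channels."""
    return (conv_stack_params(in_channels, filters, kernel_size)
            + conv_stack_params(filters, filters, kernel_size))


def module_params(in_channels, filters, kernel_size, blocks=4):
    """A basic_module whose `blocks` residuals are all tracked; only the first sees the raw input."""
    total = residual_params(in_channels, filters, kernel_size)
    total += (blocks - 1) * residual_params(filters, filters, kernel_size)
    return total


print(conv_stack_params(3, 10, 3))  # 300
print(residual_params(3, 10, 3))    # 1230
print(module_params(3, 10, 3))      # 6810
```

Under the same assumptions, a basic_module whose four residual blocks are properly tracked should report 6,810 trainable parameters, which is a useful target when checking a fix.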

However, if I try the basic_module class, I get zero trainable parameters:

input_layer = Input((128, 128, 3))
conv = basic_module(10, 3, 1)(input_layer)

model = Model(input_layer, conv)
model.summary()

Does anyone know why this is happening?

Edit to add:

I discovered that the layers used in call() must be instantiated in the class's __init__ for things to work properly. So if I change basic_module to this:

class basic_module(Layer):
    def __init__(self, filters, kernel_size, strides):
        super(basic_module, self).__init__()
        self.clayers = [basic_residual(filters, kernel_size, strides) for _ in range(4)]

    def call(self, x):
        for layer in self.clayers:
            x = layer(x)
        return x

Everything works fine. I don't know why this is the case, so I'll leave the question open in case someone can explain why.
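A minimal way to see the difference, sketched with a single Dense layer instead of the convolutional stack so it runs quickly. CreatesInCall and CreatesInInit are hypothetical names for illustration: the first mirrors the original basic_module (sublayer built inside call()), the second mirrors the edited one (sublayer built once in __init__ and assigned to an attribute).

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Layer


class CreatesInCall(Layer):
    """Like the original basic_module: a new sublayer is constructed inside call()."""

    def call(self, x):
        # Local object, never assigned to self, so the outer layer cannot track it.
        return Dense(4)(x)


class CreatesInInit(Layer):
    """Like the edited basic_module: the sublayer is created once in __init__."""

    def __init__(self):
        super().__init__()
        self.dense = Dense(4)  # attribute assignment registers it as a sublayer

    def call(self, x):
        return self.dense(x)


x = tf.ones((2, 3))
untracked = CreatesInCall()
tracked = CreatesInInit()

first = untracked(x)
second = untracked(x)  # builds yet another Dense with freshly initialized weights
tracked(x)

print(len(untracked.trainable_weights))  # 0: the Dense layers are invisible to the outer layer
print(len(tracked.trainable_weights))    # 2: the Dense kernel and bias
print(np.allclose(np.asarray(first), np.asarray(second)))  # almost surely False
```

Besides hiding the weights from model.summary() and from the optimizer, the call()-time pattern also builds new randomly initialized layers on every forward pass, so nothing learned could persist between calls.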

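The problem is the marked line in your original basic_module: self.res = basic_residual stores the class itself, not an instance, so every pass through call() constructs four brand-new basic_residual layers. Keras only tracks sublayers that are assigned to attributes of the layer (directly, or inside a list or dict attribute as in your edit), so those layers are never registered with basic_module, their weights do not appear in its trainable weights, and the summary reports zero trainable parameters. You have to create the basic_residual instances, with the required filters, kernel_size and strides arguments, in basic_module's __init__ instead.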

# >>> a = basic_module
# >>> a
# __main__.basic_module
# >>> a = basic_module(10, 3, 1)
# >>> a
# <__main__.basic_module at 0x7f6123eed510>

class basic_module(Layer):
    def __init__(self, filters, kernel_size, strides):
        super(basic_module, self).__init__()
        self.res = basic_residual  # <--- the class itself, not a layer instance
        self.args = (filters, kernel_size, strides)
    
    def call(self, x):
        for _ in range(4):
            x = self.res(*self.args)(x)
        return x
