2x nested Tensorflow custom layers results in zero trainable parameters

I am creating a series of custom Tensorflow (version 2.4.1) layers and am running into a problem where the model summary shows zero trainable parameters. Below is a series of examples showing how everything is fine until I add in the last custom layer.

Here are the imports and custom classes:

from tensorflow.keras.models import Model
from tensorflow.keras.layers import (BatchNormalization, Conv2D, Input, ReLU, 
                                     Layer)

class basic_conv_stack(Layer):
    def __init__(self, filters, kernel_size, strides):
        super(basic_conv_stack, self).__init__()
        self.conv1 = Conv2D(filters, kernel_size, strides, padding='same')
        self.bn1 = BatchNormalization()
        self.relu = ReLU()

    def call(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        return x
    
class basic_residual(Layer):
    def __init__(self, filters, kernel_size, strides):
        super(basic_residual, self).__init__()
        self.bcs1 = basic_conv_stack(filters, kernel_size, strides)
        self.bcs2 = basic_conv_stack(filters, kernel_size, strides)

    def call(self, x):
        x = self.bcs1(x)
        x = self.bcs2(x)
        return x
    
class basic_module(Layer):
    def __init__(self, filters, kernel_size, strides):
        super(basic_module, self).__init__()
        self.res = basic_residual
        self.args = (filters, kernel_size, strides)
    
    def call(self, x):
        for _ in range(4):
            x = self.res(*self.args)(x)
        return x

Now, if I do the following, everything works out OK and I get 300 trainable parameters:

input_layer = Input((128, 128, 3))
conv = basic_conv_stack(10, 3, 1)(input_layer)

model = Model(input_layer, conv)
print (model.summary())
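
As a quick sanity check (my own arithmetic, not part of the original post), the 300 trainable parameters are the Conv2D kernel and bias plus BatchNormalization's trainable gamma and beta vectors:

# Rough breakdown of the 300 trainable parameters (my own check)
conv_params = 3 * 3 * 3 * 10 + 10   # kernel_h * kernel_w * in_channels * filters + bias = 280
bn_params = 2 * 10                  # gamma and beta are trainable; moving mean/variance are not
print(conv_params + bn_params)      # 300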

Similarly, if I do the following, I get 1,230 trainable parameters:

input_layer = Input((128, 128, 3))
conv = basic_residual(10, 3, 1)(input_layer)

model = Model(input_layer, conv)
print (model.summary())
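
The same kind of rough check (again my own arithmetic) accounts for the 1,230 parameters; the second basic_conv_stack now receives 10 input channels from the first one:

# Rough breakdown of the 1,230 trainable parameters (my own check)
stack1 = (3 * 3 * 3 * 10 + 10) + 2 * 10    # 300, as above
stack2 = (3 * 3 * 10 * 10 + 10) + 2 * 10   # 930: the second stack sees 10 input channels
print(stack1 + stack2)                      # 1230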

However, if I try the basic_module class, I get zero trainable parameters:

input_layer = Input((128, 128, 3))
conv = basic_module(10, 3, 1)(input_layer)

model = Model(input_layer, conv)
print (model.summary())

Does anyone know why this is happening?

Edit to add:

I discovered that the layers used in the call must be initialized in the class's __init__ for things to work properly. So if I change the basic module to this:

class basic_module(Layer):
    def __init__(self, filters, kernel_size, strides):
        super(basic_module, self).__init__()
        self.clayers = [basic_residual(filters, kernel_size, strides) for _ in range(4)]

    def call(self, x):
        for idx in range(4):
            x = self.clayers[idx](x)
        return x

Everything works fine. I don't know why this is the case, so I'll leave this question open in case someone can explain why.

You have to initialize the class instances with the required parameters such as filters, kernel_size, and strides in the predefined basic_module; assigning the bare class is not enough. Also, note that these hyper-parameters are tied to the layer's trainable weights.

# >>> a = basic_module
# >>> a
# __main__.basic_module
# >>> a = basic_module(10, 3, 1)
# >>> a
# <__main__.basic_module at 0x7f6123eed510>

class basic_module(Layer):
    def __init__(self, filters, kernel_size, strides):
        super(basic_module, self).__init__()
        self.res = basic_residual # <--- this assigns the class itself, not an instance
        self.args = (filters, kernel_size, strides)
    
    def call(self, x):
        for _ in range(4):
            x = self.res(*self.args)(x)
        return x
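
One way to see the difference in practice (a sketch of my own, using a dummy tf.zeros input rather than anything from the original post): Keras only tracks the weights of sub-layers assigned to attributes in __init__, so the corrected basic_module from the edit above reports its variables, while the original one, which builds fresh basic_residual instances inside call, reports none:

# Sketch: verify weight tracking with the corrected basic_module from the edit above
# (the dummy input shape is my own choice)
import tensorflow as tf

m = basic_module(10, 3, 1)           # corrected version with self.clayers built in __init__
_ = m(tf.zeros((1, 128, 128, 3)))    # calling the layer once builds its sub-layers
print(len(m.trainable_weights))      # non-zero, because the sub-layers are tracked as attributes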
