
Defining model blocks in tf.keras

I'm experimenting with my model's architecture, and I would like to have several predefined blocks of layers that I can mix at will. I thought that creating a separate class for each of these block structures would make this easier, and I figured that subclassing the Model class in tf.keras was the way to go. So I did the following (a toy example, but long; sorry).

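import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import BatchNormalization, Conv1D, Dense, MaxPooling1D

class PoolingBlock(Model):
    def __init__(self, filters, stride, name):
        super(PoolingBlock, self).__init__(name=name)

        self.bn = BatchNormalization()
        self.conv1 = Conv1D(filters=filters, kernel_size=1, padding='same')
        self.mp1 = MaxPooling1D(pool_size=stride, padding='same')

    def call(self, input_tensor, training=False, mask=None):
        # BN -> ReLU -> 1x1 convolution -> max pooling
        # Forward `training` so BatchNormalization switches between batch and moving statistics
        x = self.bn(input_tensor, training=training)
        x = tf.nn.relu(x)
        x = self.conv1(x)
        x = self.mp1(x)
        return x

class ModelA(Model):
    def __init__(self, n_dense, filters, stride, name):
        super(ModelA, self).__init__(name=name)

        # Layer names go in name=...; the second positional argument of Dense is `activation`
        self.d1 = Dense(n_dense, name="DenseLayer1")
        self.pb1 = PoolingBlock(filters, stride, name="PoolingBlock_1")
        self.d2 = Dense(n_dense, name="DenseLayer2")

    def call(self, inputs, training=False, mask=None):
        x = inputs
        x = self.d1(x)
        x = self.pb1(x, training=training)
        x = self.d2(x)
        return x

# x is the training input array, e.g. of shape (batch, steps, channels)
model = ModelA(100, 10, 2, 'ModelA')
model.build(input_shape=x.shape)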

Then I continue with model.compile(...) and model.fit(...) as usual. But when training, I receive this warning:

WARNING:tensorflow:Entity <bound method PoolingBlock.call of <model.PoolingBlock object at 0x7fe09ca04208>> could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, export AUTOGRAPH_VERBOSITY=10) and attach the full output. Cause: converting <bound method PoolingBlock.call of <model.PoolingBlock object at 0x7fe09ca04208>>: AttributeError: module 'gast' has no attribute 'Num'

I don't understand what this means. I'd like to know whether my training is going as planned, whether this way of subclassing is correct and robust, and whether I can suppress this warning somehow.

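The warning comes from a version mismatch between AutoGraph and the gast package: gast 0.3 and later no longer define gast.Num, which the AutoGraph shipped with older TensorFlow releases (such as 2.0) still uses. When the conversion fails, PoolingBlock.call is simply executed as-is, i.e. as plain Python without AutoGraph conversion. Since that method contains no tensor-dependent control flow, the model should still train as intended.

To get rid of the warning, try downgrading gast:

pip install gast==0.2.2

Then re-train the network. If you only want to hide the message without fixing the cause, tf.get_logger().setLevel('ERROR') stops TensorFlow's logger from printing warnings.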
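As for whether subclassing Model for every block is solid: it works, but the Keras subclassing guide recommends using the Layer class for inner computation blocks and the Model class only for the outer object you compile and train, since Model adds fit/evaluate/predict and saving APIs that an inner block does not need. Below is a minimal sketch of that variant, assuming TensorFlow 2.x; the random x and y arrays are made-up data used only to exercise the model.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers


class PoolingBlock(layers.Layer):
    """BN -> ReLU -> 1x1 Conv1D -> MaxPooling1D, packaged as a reusable layer."""

    def __init__(self, filters, stride, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.bn = layers.BatchNormalization()
        self.conv1 = layers.Conv1D(filters=filters, kernel_size=1, padding="same")
        self.mp1 = layers.MaxPooling1D(pool_size=stride, padding="same")

    def call(self, inputs, training=None):
        # Forward `training` so BatchNormalization uses batch statistics only while training.
        x = self.bn(inputs, training=training)
        x = tf.nn.relu(x)
        x = self.conv1(x)
        return self.mp1(x)


class ModelA(tf.keras.Model):
    """Outer model: the only class here that is compiled and fitted."""

    def __init__(self, n_dense, filters, stride, name=None):
        super().__init__(name=name)
        self.d1 = layers.Dense(n_dense, name="DenseLayer1")
        self.pb1 = PoolingBlock(filters, stride, name="PoolingBlock_1")
        self.d2 = layers.Dense(n_dense, name="DenseLayer2")

    def call(self, inputs, training=None):
        x = self.d1(inputs)
        x = self.pb1(x, training=training)
        return self.d2(x)


# Hypothetical data: 8 sequences of 20 steps with 4 channels each.
# MaxPooling1D with pool_size=2 and padding="same" halves 20 steps to 10.
x = np.random.rand(8, 20, 4).astype("float32")
y = np.random.rand(8, 10, 100).astype("float32")

model = ModelA(100, 10, 2, name="ModelA")
model.compile(optimizer="adam", loss="mse")
model.fit(x, y, epochs=1, verbose=0)
print(tuple(model(x).shape))  # (8, 10, 100)
```

Blocks written this way can be stacked inside any outer Model just like built-in layers, and their weights still appear in model.trainable_weights.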


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM