
Share weights for a block of layers in keras

In this other question, it was shown that one can reuse a Dense layer on different Input layers to enable weight sharing. I am now wondering how to extend this principle to an entire block of layers; my attempt is as follows:

from keras.layers import Input, Dense, BatchNormalization, PReLU
from keras.initializers import Constant
from keras import backend as K

def embedding_block(dim):
    dense = Dense(dim, activation=None, kernel_initializer='glorot_normal')
    activ = PReLU(alpha_initializer=Constant(value=0.25))(dense)
    bnorm = BatchNormalization()(activ)
    return bnorm

def embedding_stack():
    return embedding_block(32)(embedding_block(16)(embedding_block(8)))

common_embedding = embedding_stack()

Here I am creating "embedding blocks", each containing a single dense layer of variable dimension, which I am trying to string together into an "embedding stack" made of blocks of increasing dimension. I would then like to apply this "common embedding" to several Input layers (all of the same shape) so that the weights are shared.

The above code fails with

<ipython-input-33-835f06ed7bbb> in embedding_block(dim)
      1 def embedding_block(dim):
      2     dense = Dense(dim, activation=None, kernel_initializer='glorot_normal')
----> 3     activ = PReLU(alpha_initializer=Constant(value=0.25))(dense)
      4     bnorm = BatchNormalization()(activ)
      5     return bnorm

/localenv/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
    980       with ops.name_scope_v2(name_scope):
    981         if not self.built:
--> 982           self._maybe_build(inputs)
    983 
    984         with ops.enable_auto_cast_variables(self._compute_dtype_object):

/localenv/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _maybe_build(self, inputs)
   2641         # operations.
   2642         with tf_utils.maybe_init_scope(self):
-> 2643           self.build(input_shapes)  # pylint:disable=not-callable
   2644       # We must set also ensure that the layer is marked as built, and the build
   2645       # shape is stored since user defined build functions may not be calling

/localenv/lib/python3.8/site-packages/tensorflow/python/keras/utils/tf_utils.py in wrapper(instance, input_shape)
    321     if input_shape is not None:
    322       input_shape = convert_shapes(input_shape, to_tuples=True)
--> 323     output_shape = fn(instance, input_shape)
    324     # Return shapes from `fn` as TensorShapes.
    325     if output_shape is not None:

/localenv/lib/python3.8/site-packages/tensorflow/python/keras/layers/advanced_activations.py in build(self, input_shape)
    138   @tf_utils.shape_type_conversion
    139   def build(self, input_shape):
--> 140     param_shape = list(input_shape[1:])
    141     if self.shared_axes is not None:
    142       for i in self.shared_axes:

TypeError: 'NoneType' object is not subscriptable

What is the proper way to do this? Thanks!

You need to dissociate layer instantiation from model creation. In your code, PReLU(...)(dense) is called on the Dense layer object itself rather than on an output tensor, so Keras has no input shape with which to build the activation (hence the 'NoneType' object is not subscriptable error).

Here is a simple method using a for loop:

from tensorflow.keras import layers, initializers, Model

def embedding_block(dim):
    # Instantiate the layers without calling them on any tensor yet
    dense = layers.Dense(dim, activation=None, kernel_initializer='glorot_normal')
    activ = layers.PReLU(alpha_initializer=initializers.Constant(value=0.25))
    bnorm = layers.BatchNormalization()
    return [dense, activ, bnorm]

# One flat list of layers: Dense(8) -> PReLU -> BN -> Dense(16) -> ... -> BN
stack = embedding_block(8) + embedding_block(16) + embedding_block(32)

inp1 = layers.Input((5,))
inp2 = layers.Input((5,))

# Apply the same layer instances to both inputs so that the weights are shared
x, y = inp1, inp2
for layer in stack:
    x = layer(x)
    y = layer(y)

concat_layer = layers.Concatenate()([x, y])
pred = layers.Dense(1, activation="sigmoid")(concat_layer)

model = Model(inputs=[inp1, inp2], outputs=pred)

We first create each layer once, and then iterate through them with the functional API, applying the same layer instances to both inputs so that their weights are shared.
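If you want a quick programmatic check (a minimal sketch on top of the model above, not part of the original answer), you can confirm that both branches go through the very same layer objects, so each shared kernel exists only once:

# Each shared layer appears exactly once in the model, even though both branches use it
shared_dense = stack[0]                                         # the Dense(8) layer
print(shared_dense.kernel.shape)                                # (5, 8): a single kernel serves both inputs
print(sum(isinstance(l, layers.Dense) for l in model.layers))   # 4: three shared Dense layers + the output Dense
print(len(model.trainable_weights))                             # each shared weight is counted only once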

You can also inspect the network in Netron to see that the weights are indeed shared:

[Netron visualization of the network: the two inputs at the top feed into the same shared layers]
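An equivalent pattern, offered here only as a sketch of a common Keras idiom rather than part of the answer above, is to wrap the shared block in its own Model and call that sub-model on each input; the helper name shared_embedding_stack below is purely illustrative:

from tensorflow.keras import layers, initializers, Model

def shared_embedding_stack(input_dim):
    # Build the shared block once as a nested Model
    inp = layers.Input((input_dim,))
    x = inp
    for dim in (8, 16, 32):
        x = layers.Dense(dim, activation=None, kernel_initializer='glorot_normal')(x)
        x = layers.PReLU(alpha_initializer=initializers.Constant(value=0.25))(x)
        x = layers.BatchNormalization()(x)
    return Model(inp, x, name="shared_embedding")

shared = shared_embedding_stack(5)

inp1 = layers.Input((5,))
inp2 = layers.Input((5,))

# Calling the same sub-model on both inputs reuses its weights
concat = layers.Concatenate()([shared(inp1), shared(inp2)])
pred = layers.Dense(1, activation="sigmoid")(concat)
model = Model(inputs=[inp1, inp2], outputs=pred)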
