
How do I “flatten” a model within a model in Keras?

So let's say I have a Keras model that I have built and like:

from keras.layers import Input, Dense, BatchNormalization, Subtract
from keras.models import Model

input_layer = Input((10, ))
x = Dense(5)(input_layer)
output_layer = Dense(10)(x)

model = Model(input_layer, output_layer)

I can learn about my model with the summary() function and I get:

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 10)                0
_________________________________________________________________
dense_1 (Dense)              (None, 5)                 55
_________________________________________________________________
dense_2 (Dense)              (None, 10)                60
=================================================================
Total params: 115
Trainable params: 115
Non-trainable params: 0
_________________________________________________________________
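The parameter counts can be checked by hand: each Dense layer has (inputs × units) weights plus one bias per unit. A minimal sanity check, using only the shapes shown above:

# Dense(5) on a 10-dimensional input: 10*5 weights + 5 biases
params_dense_1 = 10 * 5 + 5    # 55
# Dense(10) on a 5-dimensional input: 5*10 weights + 10 biases
params_dense_2 = 5 * 10 + 10   # 60
print(params_dense_1 + params_dense_2)  # 115, matching "Total params"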

Now I want to try adding some pre- and post-processing steps to the model, so for example I might do the following:

# Add preprocessing layers
new_input = Input((10,))
x = BatchNormalization()(new_input)
model.layers.pop(0) # attempt to remove the original input layer
x = model(x)
# Change the model to residual modeling with a subtract layer
new_output = Subtract()([new_input, x])
new_model = Model(new_input, new_output)

However, now when I call summary() I only see the pre- and post-processing layers, not the layers of my original model:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_2 (InputLayer)            (None, 10)           0
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 10)           40          input_2[0][0]
__________________________________________________________________________________________________
model_1 (Model)                 (None, 10)           115         batch_normalization_1[0][0]
__________________________________________________________________________________________________
subtract_1 (Subtract)           (None, 10)           0           input_2[0][0]
                                                                 model_1[1][0]
==================================================================================================
Total params: 155
Trainable params: 135
Non-trainable params: 20
__________________________________________________________________________________________________

This means that internally, rather than each layer from the first model being added to new_model.layers, the whole original model is added as a single element of the new_model.layers list. I want that single element to be flattened out so the summary looks more like this:


Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_2 (InputLayer)            (None, 10)           0
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 10)           40          input_2[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 5)            55          batch_normalization_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 10)           60          dense_1[1][0]
__________________________________________________________________________________________________
subtract_1 (Subtract)           (None, 10)           0           input_2[0][0]
                                                                 dense_2[1][0]
==================================================================================================
Total params: 155
Trainable params: 135
Non-trainable params: 20
__________________________________________________________________________________________________
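As a quick way to see the nesting described above, one can inspect new_model.layers directly (a minimal sketch, assuming the snippets above have been run; the printed list is illustrative):

# List the layer types that make up the wrapper model
print([type(layer).__name__ for layer in new_model.layers])
# Roughly: ['InputLayer', 'BatchNormalization', 'Model', 'Subtract']
# The original model shows up as one 'Model' entry instead of its Dense layers.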

Why do I care? Basically, I am trying out different modeling approaches where I can combine various pre- and post-processing steps with different base models and see how that impacts my results. However, it is really hard to do that comparative analysis when the information about the base model and its parameters is all wrapped up in a single "model" layer of the wrapper model, which otherwise only holds information about the pre- and post-processing.

This is the correct behaviour of Keras, since Model actually inherits from Layer. You can wrap the construction of the base model in a function to avoid it being wrapped into a Model:

def base_model(input=None):
  # If no input tensor is given, build and return a standalone Model;
  # otherwise apply the layers directly to the given tensor.
  input_layer = Input((10, )) if input is None else input
  x = Dense(5)(input_layer)
  output_layer = Dense(10)(x)
  if input is None:
    return Model(input_layer, output_layer)
  return output_layer
# Then use it
# Add preprocessing layers
new_input = Input((10,))
x = BatchNormalization()(new_input)
x = base_model(x)
# Change the model to residual modeling with a subtract layer
new_output = Subtract()([new_input, x])
new_model = Model(new_input, new_output)

If you pass an existing tensor, the layers are applied directly and nothing gets wrapped; if you don't, it returns a standalone Model instance.
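With this approach, the Dense layers become direct members of new_model, so they show up individually in the summary. A quick check (a minimal sketch, assuming the snippets above have been run; the printed list is illustrative):

new_model.summary()
print([type(layer).__name__ for layer in new_model.layers])
# Roughly: ['InputLayer', 'BatchNormalization', 'Dense', 'Dense', 'Subtract']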
