How do I "flatten" a model nested inside another model in Keras?


from keras.layers import Input, Dense, BatchNormalization, Subtract
from keras.models import Model

input_layer = Input((10, ))
x = Dense(5)(input_layer)
output_layer = Dense(10)(x)

model = Model(input_layer, output_layer)


Layer (type)                 Output Shape              Param #
input_1 (InputLayer)         (None, 10)                0
dense_1 (Dense)              (None, 5)                 55
dense_2 (Dense)              (None, 10)                60
Total params: 115
Trainable params: 115
Non-trainable params: 0


# Add preprocessing layers
new_input = Input((10,))
x = BatchNormalization()(new_input)
model.layers.pop(0) # remove original input
x = model(x)
# Change the model to residual modeling with a subtract layer
new_output = Subtract()([new_input, x])
new_model = Model(new_input, new_output)


Layer (type)                    Output Shape         Param #     Connected to
input_2 (InputLayer)            (None, 10)           0
batch_normalization_1 (BatchNor (None, 10)           40          input_2[0][0]
model_1 (Model)                 (None, 10)           115         batch_normalization_1[0][0]
subtract_1 (Subtract)           (None, 10)           0           input_2[0][0]
Total params: 155
Trainable params: 135
Non-trainable params: 20

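This means that internally, instead of adding each layer of the first model to new_model.layers, the whole model is added to the new_model.layers list as a single element. I would like it to actually flatten that single element, so that the summary looks like this: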

Layer (type)                    Output Shape         Param #     Connected to
input_2 (InputLayer)            (None, 10)           0
batch_normalization_1 (BatchNor (None, 10)           40          input_2[0][0]
dense_1 (Dense)                 (None, 5)            55          batch_normalization_1[0][0]
dense_2 (Dense)                 (None, 10)           60          dense_1[0][0]
subtract_1 (Subtract)           (None, 10)           0           input_2[0][0]
Total params: 155
Trainable params: 135
Non-trainable params: 20

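Why do I care? Basically, I am experimenting with different modeling approaches in which I can try different combinations of preprocessing, postprocessing, and base models, and see how they affect my results. However, comparative analysis is really difficult when all the information about the base model and its parameters is wrapped up in a single "Model" layer of a wrapper model that only shows the preprocessing and postprocessing.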

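This is the correct behavior in Keras, since Model actually inherits from Layer. You can wrap the base model in a function to keep it from being wrapped into a Model: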

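def base_model(input_tensor=None):
  # Build on the given tensor, or create a fresh Input when none is passed.
  # `is None` avoids evaluating a symbolic tensor as a boolean.
  input_layer = Input((10, )) if input_tensor is None else input_tensor
  x = Dense(5)(input_layer)
  output_layer = Dense(10)(x)
  # Standalone use: return a Model; otherwise return the output tensor.
  if input_tensor is None:
    return Model(input_layer, output_layer)
  return output_layer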
# Then use it
# Add preprocessing layers
new_input = Input((10,))
x = BatchNormalization()(new_input)
x = base_model(x)
# Change the model to residual modeling with a subtract layer
new_output = Subtract()([new_input, x])
new_model = Model(new_input, new_output)
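
If you already have a wrapper that contains a nested model and only need the flat list of layers for comparison, a recursive walk over `.layers` may be enough. The following is a minimal sketch, not part of Keras: `iter_flat_layers` is a hypothetical helper, and `FakeLayer` / `FakeModel` are stand-ins so the example runs without Keras installed. It assumes that only nested models expose a `layers` attribute, which is worth checking against your Keras version.

```python
def iter_flat_layers(layer):
    """Yield leaf layers depth-first, descending into anything with a `.layers` list."""
    sublayers = getattr(layer, "layers", None)
    if sublayers is None:
        # Plain layer: nothing nested inside it.
        yield layer
        return
    for sub in sublayers:
        yield from iter_flat_layers(sub)


# Hypothetical stand-ins that mimic only the `name` and `layers` attributes.
class FakeLayer:
    def __init__(self, name):
        self.name = name


class FakeModel(FakeLayer):
    def __init__(self, name, layers):
        super().__init__(name)
        self.layers = layers


# Mirrors the nested structure from the question's second summary.
base = FakeModel("model_1", [FakeLayer("dense_1"), FakeLayer("dense_2")])
wrapper = FakeModel("model_2", [
    FakeLayer("input_2"),
    FakeLayer("batch_normalization_1"),
    base,
    FakeLayer("subtract_1"),
])

flat_names = [layer.name for layer in iter_flat_layers(wrapper)]
print(flat_names)
```

With a real model you would pass `new_model` instead of `wrapper`. Note that this only flattens the listing; the graph itself still contains the nested Model, so `summary()` is unchanged.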



