How do I "Flatten" a model in a model in Keras
So let's say I have some Keras model that I like:
from keras.layers import Input, Dense, BatchNormalization, Subtract
from keras.models import Model
input_layer = Input((10, ))
x = Dense(5)(input_layer)
output_layer = Dense(10)(x)
model = Model(input_layer, output_layer)
I can use the summary() function to inspect my model, and I get:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 10) 0
_________________________________________________________________
dense_1 (Dense) (None, 5) 55
_________________________________________________________________
dense_2 (Dense) (None, 10) 60
=================================================================
Total params: 115
Trainable params: 115
Non-trainable params: 0
_________________________________________________________________
Now I want to try adding some pre-processing and post-processing steps to the model. For example, I can do the following:
# Add preprocessing layers
new_input = Input((10,))
x = BatchNormalization()(new_input)
model.layers.pop(0) # remove original input
x = model(x)
# Change the model to residual modeling with a subtract layer
new_output = Subtract()([new_input, x])
new_model = Model(new_input, new_output)
But now when I call summary(), I only see the pre-processing and post-processing layers, not the layers of the original model:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_2 (InputLayer) (None, 10) 0
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 10) 40 input_2[0][0]
__________________________________________________________________________________________________
model_1 (Model) (None, 10) 115 batch_normalization_1[0][0]
__________________________________________________________________________________________________
subtract_1 (Subtract) (None, 10) 0 input_2[0][0]
model_1[1][0]
==================================================================================================
Total params: 155
Trainable params: 135
Non-trainable params: 20
__________________________________________________________________________________________________
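This means that internally, instead of each layer of the first model being added to new_model.layers, the whole model is added as a single element of the new_model.layers list. I would like it to actually flatten that single element, so that the summary looks like this: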
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_2 (InputLayer) (None, 10) 0
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 10) 40 input_2[0][0]
__________________________________________________________________________________________________
dense_1 (Dense) (None, 5) 55 batch_normalization_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense) (None, 10) 60 dense_1[0]
__________________________________________________________________________________________________
subtract_1 (Subtract) (None, 10) 0 input_2[0][0]
dense_2[1][0]
==================================================================================================
Total params: 155
Trainable params: 135
Non-trainable params: 20
__________________________________________________________________________________________________
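Why do I care? Basically, I am trying out different modeling approaches in which I combine different pre-processing and post-processing steps with different base models, and I want to see how each combination affects my results. However, comparative analysis is really difficult when all the information about the base model and its parameters is wrapped up in a single "Model" layer of a wrapper model that only shows the pre-processing and post-processing layers.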
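This is the correct behaviour of Keras, because Model actually inherits from Layer, so calling a model on a tensor adds it to the graph as one layer. You can wrap your base model in a function to stop it from being wrapped in a Model: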
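def base_model(input=None):
    # Compare with `is None` rather than using `input or ...`:
    # a symbolic tensor may not be usable as a Python bool.
    input_layer = Input((10, )) if input is None else input
    x = Dense(5)(input_layer)
    output_layer = Dense(10)(x)
    if input is None:
        return Model(input_layer, output_layer)
    return output_layer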
# Then use it
# Add preprocessing layers
new_input = Input((10,))
x = BatchNormalization()(new_input)
x = base_model(x)
# Change the model to residual modeling with a subtract layer
new_output = Subtract()([new_input, x])
new_model = Model(new_input, new_output)
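If you pass in the output of an existing layer, the base layers are added directly to your graph and nothing is wrapped; if you pass nothing, the function returns a Model instance.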