![](/img/trans.png)
[英]_SymbolicException when using intermediate model outputs in tensorflow keras
[英]keras model using intermediate layers as inputs and outputs
我在 Keras(Tensoflow 后端)中有一個基本的 LSTM 自動編碼器。 該模型的結構如下:
l0 = Input(shape=(10, 2))
l1 = LSTM(16, activation='relu', return_sequences=True)(l0)
l2 = LSTM(8, activation='relu', return_sequences=False)(l1)
l3 = RepeatVector(10)(l2)
l4 = LSTM(8, activation='relu', return_sequences=True)(l3)
l5 = LSTM(16, activation='relu', return_sequences=True)(l4)
l6 = TimeDistributed(Dense(2))(l5)
我可以按如下方式提取和編譯編碼器和自動編碼器:
encoder = Model(l0, l2)
auto_encoder = Model(l0, l6)
auto_encoder.compile(optimizer='rmsprop', loss='mse', metrics=['mse'])
但是,當我嘗試使用中間層制作模型時,例如:
decoder = Model(inputs=l3, outputs=l6)
我收到以下錯誤:
ValueError: Graph disconnected: cannot obtain value for tensor Tensor("input_12:0", shape=(?, 10, 2), dtype=float32) at layer "input_12". The following previous layers were accessed without issue: []
我不明白l3
和l6
是如何相互斷開的! 我還嘗試使用get_layer(...).input
和get_layer(...).output
制作解碼器,但它拋出了同樣的錯誤。
一個解釋對我很有幫助。
問題是您嘗試創建的模型沒有輸入層:
decoder = Model(inputs=l3, outputs=l6)
您可以通過生成一個具有正確形狀的新Input()
圖層然后訪問每個現有圖層來創建一個。 像這樣的東西:
input_layer = Input(shape=(8,))
l3 = auto_encoder.layers[3](input_layer)
l4 = auto_encoder.layers[4](l3)
l5 = auto_encoder.layers[5](l4)
l6 = auto_encoder.layers[6](l5)
decoder = Model(input_layer, l6)
decoder.summary()
Model: "model_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_14 (InputLayer) [(None, 8)] 0
_________________________________________________________________
repeat_vector_2 (RepeatVecto (None, 10, 8) 0
_________________________________________________________________
lstm_12 (LSTM) (None, 10, 8) 544
_________________________________________________________________
lstm_13 (LSTM) (None, 10, 16) 1600
_________________________________________________________________
time_distributed_1 (TimeDist (None, 10, 2) 34
=================================================================
Total params: 2,178
Trainable params: 2,178
Non-trainable params: 0
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.