
Remove top layer from pre-trained model, transfer learning, tensorflow (load_model)

I have pre-trained a model with two classes (a model I saved myself), and I want to use it for transfer learning to train a model with six classes. I have loaded the pre-trained model into the new training script:

base_model = tf.keras.models.load_model("base_model_path")

How do I remove the top/head layer (the conv1D layer)?

I saw that in keras you can use base_model.pop(), and that for tf.keras.applications you can simply set include_top=False, but is there something similar when using tf.keras and load_model?
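For reference, the two approaches mentioned above look roughly like this; the ResNet50 backbone is only an illustrative example, not part of my setup:

import tensorflow as tf

# Sequential models have pop(), which removes the last layer in place:
seq = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(2, activation='softmax'),
])
seq.pop()  # drops the 2-class softmax head

# The tf.keras.applications models take include_top=False instead:
backbone = tf.keras.applications.ResNet50(include_top=False, weights=None,
                                          input_shape=(64, 64, 3))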

(I have tried something like this:

for layer in base_model.layers[:-1]:
    layer.trainable = False

and then adding it to the new model (?), but I am not sure how to proceed.)

Thank you for your help!

You could try something like this:

The base model consists of a simple Conv1D network with an output layer for two classes:

import tensorflow as tf

samples = 100
timesteps = 5
features = 2
classes = 2
dummy_x, dummy_y = tf.random.normal((samples, timesteps, features)), tf.random.uniform((samples, 1), maxval=classes, dtype=tf.int32)

base_model = tf.keras.Sequential()
base_model.add(tf.keras.layers.Conv1D(32, 3, activation='relu', input_shape=(5, 2)))
base_model.add(tf.keras.layers.GlobalMaxPool1D())
base_model.add(tf.keras.layers.Dense(32, activation='relu'))
base_model.add(tf.keras.layers.Dense(classes, activation='softmax'))

base_model.compile(optimizer='adam', loss = tf.keras.losses.SparseCategoricalCrossentropy())
print(base_model.summary())
base_model.fit(dummy_x, dummy_y, batch_size=16, epochs=1)
base_model.save("base_model")
base_model = tf.keras.models.load_model("base_model")
Model: "sequential_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv1d_31 (Conv1D)          (None, 3, 32)             224       
                                                                 
 global_max_pooling1d_13 (Gl  (None, 32)               0         
 obalMaxPooling1D)                                               
                                                                 
 dense_17 (Dense)            (None, 32)                1056      
                                                                 
 dense_18 (Dense)            (None, 2)                 66        
                                                                 
=================================================================
Total params: 1,346
Trainable params: 1,346
Non-trainable params: 0
_________________________________________________________________
None
7/7 [==============================] - 0s 3ms/step - loss: 0.6973
INFO:tensorflow:Assets written to: base_model/assets
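
As a side note, if the goal is only to cut off the two-class head of the loaded model (rather than also swapping out the first Conv1D layer, as below), another common pattern is to rebuild the model around an intermediate output; a minimal sketch using the base_model loaded above:

# Rebuild the loaded model without its 2-class softmax head.
# base_model.layers[-2] is the 32-unit Dense layer shown in the summary above.
headless = tf.keras.Model(inputs=base_model.input,
                          outputs=base_model.layers[-2].output)
print(headless.output_shape)  # (None, 32)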

The new model also consists of a simple Conv1D network, but with an output layer for six classes. It also contains all the layers of base_model except the first Conv1D layer and the last output layer:

classes = 6
dummy_x, dummy_y = tf.random.normal((100, 5, 2)), tf.random.uniform((100, 1), maxval=6, dtype=tf.int32)
model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv1D(64, 3, activation='relu', input_shape=(5, 2)))
model.add(tf.keras.layers.Conv1D(32, 2, activation='relu'))
for layer in base_model.layers[1:-1]: # Skip first and last layer
  model.add(layer)
model.add(tf.keras.layers.Dense(classes, activation='softmax'))
model.compile(optimizer='adam', loss = tf.keras.losses.SparseCategoricalCrossentropy())
print(model.summary())
model.fit(dummy_x, dummy_y, batch_size=16, epochs=1)
Model: "sequential_9"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv1d_32 (Conv1D)          (None, 3, 64)             448       
                                                                 
 conv1d_33 (Conv1D)          (None, 2, 32)             4128      
                                                                 
 global_max_pooling1d_13 (Gl  (None, 32)               0         
 obalMaxPooling1D)                                               
                                                                 
 dense_17 (Dense)            (None, 32)                1056      
                                                                 
 dense_19 (Dense)            (None, 6)                 198       
                                                                 
=================================================================
Total params: 5,830
Trainable params: 5,830
Non-trainable params: 0
_________________________________________________________________
None
7/7 [==============================] - 0s 3ms/step - loss: 1.8069
<keras.callbacks.History at 0x7f90c87a3c50>
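
If the layers taken over from base_model should keep their pre-trained weights, as in the trainable loop from the question, they can be frozen and the model recompiled; a small sketch along those lines:

# Freeze the reused layers so only the new Conv1D layers and the new head train.
# The layers added to the new model are the same objects as in base_model,
# so setting trainable on base_model's layers also affects the new model.
for layer in base_model.layers[1:-1]:
    layer.trainable = False

# Recompile so the change in trainable status takes effect, then check the summary:
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy())
print(model.summary())  # the transferred parameters now show up as non-trainable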
