簡體   English   中英

Keras:如何保存模型或重量?

[英]Keras: How to save models or weights?

如果這個問題看起來很簡單,我很抱歉。 但是閱讀 Keras 保存和恢復幫助頁面:

https://www.tensorflow.org/beta/tutorials/keras/save_and_restore_models

我不明白如何在訓練期間使用“ModelCheckpoint”進行保存。 幫助文件提到它應該提供 3 個文件,我只看到一個,MODEL.ckpt。

這是我的代碼:

checkpoint_dir = FolderName + "/tmp/model.ckpt"
cp_callback = k.callbacks.ModelCheckpoint(checkpoint_dir,verbose=1,save_weights_only=True)    
parallel_model.compile(optimizer=tf.keras.optimizers.Adam(lr=learning_rate),loss=my_cost_MSE, metrics=['accuracy])
    parallel _model.fit(image, annotation, epochs=epoch, 
    batch_size=batch_size, steps_per_epoch=10,
                                 validation_data=(image_val,annotation_val),validation_steps=num_batch_val,callbacks=callbacks_list)

另外,當我想在訓練后加載權重時:

model = k.models.load_model(file_checkpoint)

我得到錯誤:

"raise ValueError('Unknown ' + printable_module_name + ':' + object_name) 
ValueError: Unknown loss function:my_cost_MSE"

my-cost_MSE 是我在訓練中使用的成本 function。

首先,看起來您正在使用tf.keras (來自 tensorflow)實現而不是keras (來自 keras-team/keras repo)。 在這種情況下,如tf.keras 指南中所述

保存模型的權重時,tf.keras 默認為檢查點格式。 通過 save_format='h5' 使用 HDF5。

另一方面,請注意,添加回調ModelCheckpoint通常大致相當於在每個 epoch 結束時調用model.save(...) ,所以這就是為什么您應該期望保存三個文件(根據檢查點格式)。

它不這樣做的原因是,通過使用選項save_weights_only=True ,您只保存了權重。 大致相當於在每個 epoch 結束時為model.save_weights替換對model.save的調用。 因此,唯一保存的文件是具有權重的文件。

從這里,您可以通過兩種不同的方式進行:

僅存儲權重

您需要預先加載您的模型(例如結構),然后調用model.load_weights而不是keras.models.load_model

model = MyModel(...)  # Your model definition as used in training
model.load_weights(file_checkpoint)

請注意,在這種情況下,自定義定義 ( my_cost_MSE ) 不會出現問題,因為您只是在加載模型權重。

存儲整個模型

另一種方法是存儲整個模型並相應地加載它:

cp_callback = k.callbacks.ModelCheckpoint(
    checkpoint_dir,verbose=1,
    save_weights_only=False
)    
parallel_model.compile(
    optimizer=tf.keras.optimizers.Adam(lr=learning_rate),
    loss=my_cost_MSE,
    metrics=['accuracy']
)

model.fit(..., callbacks=[cp_callback])

然后你可以通過以下方式加載它:

model = k.models.load_model(file_checkpoint, custom_objects={"my_cost_MSE": my_cost_MSE})

請注意,在后一種情況下,您需要指定custom_objects因為反序列化模型需要它的定義。

keras有一個save命令。 它保存了重建模型所需的所有細節。

(來自keras 文檔

from keras.models import load_model
model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'
del model  # deletes the existing model

# returns am identical compiled model
model = load_model('my_model.h5')

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM