[英]How to save trained neural network weights locally or globally?
我目前正在使用TensorFlow源來保存和恢復經過訓練的 NN model 權重:
# Save the weights
model.save_weights('./checkpoints/my_checkpoint')
# Create a new model instance
model = create_model()
# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')
我對培訓期間的檢查點也很熟悉,但我的問題是:
我們可以在訓練 model 時在本地或全局保存模型/權重,而不是將其保存到文件中嗎?
我正在使用網格搜索之類的東西,但我有一個循環,在每次迭代中,我在數據集的某些部分上訓練我的 model,然后保存訓練/學習的權重並繼續在另一組數據集上訓練/學習?
我工作的示例偽代碼:
for i in range(1,10):
- use dataset A1 for training
- train model on dataset A1
- test on the testing dataset X
- save model weights
- restore model weights
- now use dataset A2
- run model on trained weights to see initial accuracy
- retrain the model on dataset A2 and keep previously saved weights
- save model weights
end
我已經看過這樣的另一篇文章,但它沒有回答我的問題。
是的你可以。 這樣做的方法是創建一個自定義回調。 在回調中,創建了一個名為 best_weights 的 class 變量 如果驗證損失是迄今為止產生的最低損失,則該 class 變量在每個時期結束時更新。 下面是代碼。
class LRA(keras.callbacks.Callback):
def __init__(self,model, verbose ):
super(LRA, self).__init__()
self.model=model
self.lowest_vloss=np.inf # set lowest validation loss to infinity
best_weights=self.model.get_weights() # set a class vaiable so weights can be loaded after training is completed
def on_epoch_end(self, epoch, logs=None): # method runs on the end of each epoch
v_loss=logs.get('val_loss') # get the validation loss for this epoch
if v_loss< self.lowest_vloss: # check if the validation loss improved
if verbose==1:
msg=f' validation loss improved from {self.lowest_vloss:8.5f} to {v_loss:8.5}, saving best weights'
print (msg)
self.lowest_vloss=v_loss # replace lowest validation loss with new validation loss
LRA.best_weights=self.model.get_weights() # validation loss improved so save the weights
在 model.fit 中包含回調 =[LRA(model, verbose=1)]
回調使 class 變量 LRA.best_weights 可用。 它包含實現最低驗證損失的時期的 model 權重。 您可以在例如 model.set_weights(LRA.best_weights) 中使用它。 在回調參數 model 是你的 model。 參數詳細是 integer。 如果設置為 1,則在 epoch 結束時,如果驗證損失有所改善,則會打印一條消息,表明已保存最佳權重。 如果詳細不是 = 1,則不打印任何消息。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.