簡體   English   中英

如何在本地或全局保存訓練好的神經網絡權重?

[英]How to save trained neural network weights locally or globally?

我目前正在使用TensorFlow源來保存和恢復經過訓練的 NN model 權重:

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Create a new model instance
model = create_model()

# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')

我對培訓期間的檢查點也很熟悉,但我的問題是:

我們可以在訓練 model 時在本地或全局保存模型/權重,而不是將其保存到文件中嗎?

我正在使用網格搜索之類的東西,但我有一個循環,在每次迭代中,我在數據集的某些部分上訓練我的 model,然后保存訓練/學習的權重並繼續在另一組數據集上訓練/學習?

我工作的示例偽代碼:

for i in range(1,10):
    - use dataset A1 for training
    - train model on dataset A1
    - test on the testing dataset X
    - save model weights
    - restore model weights
    - now use dataset A2
    - run model on trained weights to see initial accuracy
    - retrain the model on dataset A2 and keep previously saved weights
    - save model weights
end

我已經看過這樣的另一篇文章,但它沒有回答我的問題。

是的你可以。 這樣做的方法是創建一個自定義回調。 在回調中,創建了一個名為 best_weights 的 class 變量 如果驗證損失是迄今為止產生的最低損失,則該 class 變量在每個時期結束時更新。 下面是代碼。

class LRA(keras.callbacks.Callback):
    def __init__(self,model, verbose ):
        super(LRA, self).__init__()
        self.model=model
        self.lowest_vloss=np.inf # set lowest validation loss to infinity        
        best_weights=self.model.get_weights() # set a class vaiable so weights can be loaded after training is completed           
         
    def on_epoch_end(self, epoch, logs=None):  # method runs on the end of each epoch
        v_loss=logs.get('val_loss')  # get the validation loss for this epoch
        if v_loss< self.lowest_vloss: # check if the validation loss improved
            if verbose==1:
                msg=f' validation loss improved from {self.lowest_vloss:8.5f} to {v_loss:8.5}, saving best weights' 
                print (msg)
            self.lowest_vloss=v_loss # replace lowest validation loss with new validation loss                
            LRA.best_weights=self.model.get_weights() # validation loss improved so save the weights

在 model.fit 中包含回調 =[LRA(model, verbose=1)]

回調使 class 變量 LRA.best_weights 可用。 它包含實現最低驗證損失的時期的 model 權重。 您可以在例如 model.set_weights(LRA.best_weights) 中使用它。 在回調參數 model 是你的 model。 參數詳細是 integer。 如果設置為 1,則在 epoch 結束時,如果驗證損失有所改善,則會打印一條消息,表明已保存最佳權重。 如果詳細不是 = 1,則不打印任何消息。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM