簡體   English   中英

在每N個時期結束時保存模型權重

[英]save model weights at the end of every N epochs

我正在訓練NN,並希望在預測階段每N個時期保存模型權重。 我提出這個草案代碼,它的靈感來自於@grovina 在這里的回應。 請你提出建議嗎? 提前致謝。

from keras.callbacks import Callback

class WeightsSaver(Callback):
    def __init__(self, model, N):
        self.model = model
        self.N = N
        self.epoch = 0

    def on_batch_end(self, epoch, logs={}):
        if self.epoch % self.N == 0:
            name = 'weights%08d.h5' % self.epoch
            self.model.save_weights(name)
        self.epoch += 1

然后將其添加到fit調用:每5個時期保存一次權重:

model.fit(X_train, Y_train, callbacks=[WeightsSaver(model, 5)])

您不應該為回調傳遞模型。 它已經通過它的超級訪問模型。 所以刪除__init__(..., model, ...)參數和self.model = model self.model您應該能夠通過self.model訪問當前模型。 您也在每個批次結束時保存它,這不是您想要的,您可能希望它是on_epoch_end

但無論如何,你正在做的事情可以通過天真的modelcheckpoint回調來完成。 您不需要編寫自定義的。 你可以使用如下;

mc = keras.callbacks.ModelCheckpoint('weights{epoch:08d}.h5', 
                                     save_weights_only=True, period=5)
model.fit(X_train, Y_train, callbacks=[mc])

你應該在on_epoch_end上實現而不是實現on_batch_end。 並且將模型作為__init__參數傳遞是多余的。

from keras.callbacks import Callback
class WeightsSaver(Callback):
  def __init__(self, N):
    self.N = N
    self.epoch = 0

  def on_epoch_end(self, epoch, logs={}):
    if self.epoch % self.N == 0:
      name = 'weights%08d.h5' % self.epoch
      self.model.save_weights(name)
    self.epoch += 1

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM