簡體   English   中英

每 10 個 epoch 保存一次模型 tensorflow.keras v2

[英]Save model every 10 epochs tensorflow.keras v2

我在 tensorflow v2 中使用定義為子模塊的 keras。 我正在使用fit_generator()方法訓練我的模型。 我想每 10 個 epoch 保存一次模型。 我怎樣才能做到這一點?

在 Keras(不是 tf 的子模塊)中,我可以給出ModelCheckpoint(model_savepath,period=10) 但是在 tf v2 中,他們將其更改為ModelCheckpoint(model_savepath, save_freq)其中save_freq可以是'epoch' ,在這種情況下,模型會在每個紀元保存。 如果save_freq是整數,則在處理了這么多樣本后保存模型。 但我希望它在 10 個時代之后。 我怎樣才能做到這一點?

使用tf.keras.callbacks.ModelCheckpoint使用save_freq='epoch'並傳遞一個額外的參數period=10

盡管官方文檔中沒有記錄,但這就是這樣做的方法(注意記錄表明您可以通過period ,只是沒有解釋它的作用)。

顯式計算每個時期的批次數對我有用。

BATCH_SIZE = 20
STEPS_PER_EPOCH = train_labels.size / BATCH_SIZE
SAVE_PERIOD = 10

# Create a callback that saves the model's weights every 10 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, 
    verbose=1, 
    save_weights_only=True,
    save_freq= int(SAVE_PERIOD * STEPS_PER_EPOCH))

# Train the model with the new callback
model.fit(train_images, 
          train_labels,
          batch_size=BATCH_SIZE,
          steps_per_epoch=STEPS_PER_EPOCH,
          epochs=50, 
          callbacks=[cp_callback],
          validation_data=(test_images,test_labels),
          verbose=0)

已接受答案中提到的參數period現在不再可用。

文檔中所述,使用save_freq參數是一種替代方法,但有風險; 例如,如果數據集大小發生變化,它可能會變得不穩定:請注意,如果保存與時期不一致,則監控的指標可能不太可靠(再次取自文檔)。

因此,我使用子類作為解決方案:

class EpochModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):

    def __init__(self,
                 filepath,
                 frequency=1,
                 monitor='val_loss',
                 verbose=0,
                 save_best_only=False,
                 save_weights_only=False,
                 mode='auto',
                 options=None,
                 **kwargs):
        super(EpochModelCheckpoint, self).__init__(filepath, monitor, verbose, save_best_only, save_weights_only,
                                                   mode, "epoch", options)
        self.epochs_since_last_save = 0
        self.frequency = frequency

    def on_epoch_end(self, epoch, logs=None):
        self.epochs_since_last_save += 1
        # pylint: disable=protected-access
        if self.epochs_since_last_save % self.frequency == 0:
            self._save_model(epoch=epoch, batch=None, logs=logs)

    def on_train_batch_end(self, batch, logs=None):
        pass

用它作為

callbacks=[
     EpochModelCheckpoint("/your_save_location/epoch{epoch:02d}", frequency=10),
]

請注意,根據您的 TF 版本,您可能必須將調用中的 args 更改為超類__init__

我也來到這里尋找這個答案,並想指出與之前答案的一些變化。 我目前使用的是 TF 版本 2.5.0,並且period=正在工作,但前提是回調中沒有save_freq=

my_callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath=path
        period=N
    )
]

即使回調文檔中沒有記錄期間,這對我來說也沒有問題

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM