[英]Save model every 10 epochs tensorflow.keras v2
我在 tensorflow v2 中使用定義為子模塊的 keras。 我正在使用fit_generator()
方法訓練我的模型。 我想每 10 個 epoch 保存一次模型。 我怎樣才能做到這一點?
在 Keras(不是 tf 的子模塊)中,我可以給出ModelCheckpoint(model_savepath,period=10)
。 但是在 tf v2 中,他們將其更改為ModelCheckpoint(model_savepath, save_freq)
其中save_freq
可以是'epoch'
,在這種情況下,模型會在每個紀元保存。 如果save_freq
是整數,則在處理了這么多樣本后保存模型。 但我希望它在 10 個時代之后。 我怎樣才能做到這一點?
使用tf.keras.callbacks.ModelCheckpoint
使用save_freq='epoch'
並傳遞一個額外的參數period=10
。
盡管官方文檔中沒有記錄,但這就是這樣做的方法(注意記錄表明您可以通過period
,只是沒有解釋它的作用)。
顯式計算每個時期的批次數對我有用。
BATCH_SIZE = 20
STEPS_PER_EPOCH = train_labels.size / BATCH_SIZE
SAVE_PERIOD = 10
# Create a callback that saves the model's weights every 10 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
verbose=1,
save_weights_only=True,
save_freq= int(SAVE_PERIOD * STEPS_PER_EPOCH))
# Train the model with the new callback
model.fit(train_images,
train_labels,
batch_size=BATCH_SIZE,
steps_per_epoch=STEPS_PER_EPOCH,
epochs=50,
callbacks=[cp_callback],
validation_data=(test_images,test_labels),
verbose=0)
已接受答案中提到的參數period
現在不再可用。
如文檔中所述,使用save_freq
參數是一種替代方法,但有風險; 例如,如果數據集大小發生變化,它可能會變得不穩定:請注意,如果保存與時期不一致,則監控的指標可能不太可靠(再次取自文檔)。
因此,我使用子類作為解決方案:
class EpochModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
def __init__(self,
filepath,
frequency=1,
monitor='val_loss',
verbose=0,
save_best_only=False,
save_weights_only=False,
mode='auto',
options=None,
**kwargs):
super(EpochModelCheckpoint, self).__init__(filepath, monitor, verbose, save_best_only, save_weights_only,
mode, "epoch", options)
self.epochs_since_last_save = 0
self.frequency = frequency
def on_epoch_end(self, epoch, logs=None):
self.epochs_since_last_save += 1
# pylint: disable=protected-access
if self.epochs_since_last_save % self.frequency == 0:
self._save_model(epoch=epoch, batch=None, logs=logs)
def on_train_batch_end(self, batch, logs=None):
pass
用它作為
callbacks=[
EpochModelCheckpoint("/your_save_location/epoch{epoch:02d}", frequency=10),
]
請注意,根據您的 TF 版本,您可能必須將調用中的 args 更改為超類__init__
。
我也來到這里尋找這個答案,並想指出與之前答案的一些變化。 我目前使用的是 TF 版本 2.5.0,並且period=
正在工作,但前提是回調中沒有save_freq=
。
my_callbacks = [
keras.callbacks.ModelCheckpoint(
filepath=path
period=N
)
]
即使回調文檔中沒有記錄期間,這對我來說也沒有問題
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.