Save model every 10 epochs tensorflow.keras v2
I'm using Keras defined as a submodule in TensorFlow v2. I'm training my model using the fit_generator() method. I want to save my model every 10 epochs. How can I achieve this?
In Keras (not as a submodule of tf), I can pass ModelCheckpoint(model_savepath, period=10). But in tf v2, they've changed this to ModelCheckpoint(model_savepath, save_freq), where save_freq can be 'epoch', in which case the model is saved every epoch. If save_freq is an integer, the model is saved after that many samples have been processed. But I want it to save after every 10 epochs. How can I achieve this?
Using tf.keras.callbacks.ModelCheckpoint, use save_freq='epoch' and pass an extra argument period=10.
Although this is not documented in the official docs, that is the way to do it (note that the docs do show you can pass period, they just don't explain what it does).
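A minimal sketch of that call (the filepath pattern is my own, for illustration; period is undocumented and has been removed in newer TF releases, so this only works on versions that still accept it):

```python
import tensorflow as tf

# Save at epoch boundaries, but only every 10th epoch.
# `period` is undocumented here and was later removed, so treat
# this as version-dependent rather than guaranteed API.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath="ckpt-{epoch:02d}.h5",  # hypothetical path pattern
    save_freq="epoch",
    period=10,
)
```

Pass it to training via callbacks=[checkpoint_cb] in model.fit().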
Explicitly computing the number of batches per epoch worked for me.
BATCH_SIZE = 20
# Integer division so save_freq below is a whole number of batches
STEPS_PER_EPOCH = train_labels.size // BATCH_SIZE
SAVE_PERIOD = 10

# Create a callback that saves the model's weights every 10 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True,
    save_freq=SAVE_PERIOD * STEPS_PER_EPOCH)

# Train the model with the new callback
model.fit(train_images,
          train_labels,
          batch_size=BATCH_SIZE,
          steps_per_epoch=STEPS_PER_EPOCH,
          epochs=50,
          callbacks=[cp_callback],
          validation_data=(test_images, test_labels),
          verbose=0)
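The epochs-to-batches conversion above can be checked in isolation (the helper name is mine, purely illustrative):

```python
def save_freq_in_batches(num_samples, batch_size, save_period_epochs):
    """Number of batches between checkpoints when saving every
    `save_period_epochs` epochs."""
    steps_per_epoch = num_samples // batch_size  # partial final batch dropped
    return save_period_epochs * steps_per_epoch

# e.g. 1000 training samples with the settings above:
print(save_freq_in_batches(1000, 20, 10))  # -> 500
```

If the dataset size or batch size changes between runs, this value must be recomputed, which is exactly the fragility the next answer points out.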
The param period mentioned in the accepted answer is not available anymore.
Using the save_freq param is an alternative, but risky, as mentioned in the docs; e.g., if the dataset size changes, it may become unstable: "Note that if the saving isn't aligned to epochs, the monitored metric may potentially be less reliable" (again taken from the docs).
Thus, I use a subclass as a solution:
class EpochModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):

    def __init__(self,
                 filepath,
                 frequency=1,
                 monitor='val_loss',
                 verbose=0,
                 save_best_only=False,
                 save_weights_only=False,
                 mode='auto',
                 options=None,
                 **kwargs):
        super(EpochModelCheckpoint, self).__init__(filepath=filepath,
                                                   monitor=monitor,
                                                   verbose=verbose,
                                                   save_best_only=save_best_only,
                                                   save_weights_only=save_weights_only,
                                                   mode=mode,
                                                   save_freq="epoch",
                                                   options=options)
        self.epochs_since_last_save = 0
        self.frequency = frequency

    def on_epoch_end(self, epoch, logs=None):
        self.epochs_since_last_save += 1
        # pylint: disable=protected-access
        if self.epochs_since_last_save % self.frequency == 0:
            self._save_model(epoch=epoch, batch=None, logs=logs)

    def on_train_batch_end(self, batch, logs=None):
        pass
Use it as:
callbacks=[
    EpochModelCheckpoint("/your_save_location/epoch{epoch:02d}", frequency=10),
]
Note that, depending on your TF version, you may have to change the args in the call to the superclass __init__.
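The scheduling logic of that subclass can be sketched without TensorFlow at all (class name is mine, for illustration): Keras calls on_epoch_end with a zero-based epoch index, and the counter fires on every frequency-th call:

```python
class EpochCounter:
    """Mimics EpochModelCheckpoint's save schedule, minus the actual saving."""

    def __init__(self, frequency):
        self.frequency = frequency
        self.epochs_since_last_save = 0
        self.saved_at = []

    def on_epoch_end(self, epoch, logs=None):
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save % self.frequency == 0:
            self.saved_at.append(epoch)  # a real callback saves the model here

counter = EpochCounter(frequency=10)
for epoch in range(50):          # 50 epochs, zero-based indices 0..49
    counter.on_epoch_end(epoch)
print(counter.saved_at)          # -> [9, 19, 29, 39, 49]
```

So with frequency=10 and epochs=50, checkpoints land at the end of epochs 10, 20, 30, 40, and 50 (indices 9, 19, 29, 39, 49).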
I came here looking for this answer too and wanted to point out a couple of changes from the previous answers. I am currently using TF version 2.5.0, and period= is working, but only if there is no save_freq= in the callback.
my_callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath=path,
        period=N
    )
]
This is working for me with no issues, even though period is not documented in the callback documentation.