

Save model every 10 epochs tensorflow.keras v2

I'm using Keras as the submodule bundled with TensorFlow v2 (tf.keras). I'm training my model with the fit_generator() method, and I want to save the model every 10 epochs. How can I achieve this?

In standalone Keras (not the tf submodule), I can pass ModelCheckpoint(model_savepath, period=10). But in TF v2 this has changed to ModelCheckpoint(model_savepath, save_freq), where save_freq can be 'epoch', in which case the model is saved every epoch. If save_freq is an integer, the model is saved after that many samples have been processed. But I want it saved every 10 epochs. How can I achieve this?

With tf.keras.callbacks.ModelCheckpoint, use save_freq='epoch' and pass the extra argument period=10.

Although this is not documented in the official docs, that is the way to do it (note that the docs do say you can pass period; they just don't explain what it does).
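To make the effect of period concrete, here is a pure-Python sketch (no TensorFlow needed) of the counter logic that, as far as I know, the legacy period argument follows: a counter is incremented at the end of every epoch, and once it reaches period the model is saved and the counter is reset. The function name legacy_period_save_epochs is hypothetical; this is a simulation of the behaviour, not the Keras source.

```python
from typing import List


def legacy_period_save_epochs(num_epochs: int, period: int) -> List[int]:
    """Return the 1-based epochs at which a checkpoint would be written.

    Hypothetical simulation of the legacy `period` counter: increment at the
    end of each epoch, save and reset once the counter reaches `period`.
    """
    saved = []
    epochs_since_last_save = 0
    for epoch in range(num_epochs):  # 0-based index, as Keras passes it
        epochs_since_last_save += 1
        if epochs_since_last_save >= period:
            epochs_since_last_save = 0
            saved.append(epoch + 1)  # report the 1-based epoch number
    return saved


print(legacy_period_save_epochs(num_epochs=50, period=10))
```

With 50 epochs and period=10 this reports saves after epochs 10, 20, 30, 40 and 50.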

Explicitly computing the number of batches per epoch, and setting save_freq to 10 epochs' worth of batches, worked for me.

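import math

import tensorflow as tf

# checkpoint_path, train_images, train_labels, test_images, test_labels
# and model are assumed to be defined elsewhere.
BATCH_SIZE = 20
# Round up so a partial final batch still counts as a step; steps must be an int.
STEPS_PER_EPOCH = math.ceil(len(train_labels) / BATCH_SIZE)
SAVE_PERIOD = 10

# Create a callback that saves the model's weights every 10 epochs
# (an integer save_freq is a number of batches, so convert epochs to batches).
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True,
    save_freq=SAVE_PERIOD * STEPS_PER_EPOCH)

# Train the model with the new callback
model.fit(train_images,
          train_labels,
          batch_size=BATCH_SIZE,
          steps_per_epoch=STEPS_PER_EPOCH,
          epochs=50,
          callbacks=[cp_callback],
          validation_data=(test_images, test_labels),
          verbose=0)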
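Why the rounding matters: in recent TF 2.x releases an integer save_freq is counted in batches, and the number of batches per epoch is the ceiling of samples / batch size. Truncating a fractional value instead gives a save_freq that is not a multiple of the epoch length. A small sketch with a hypothetical dataset of 1000 samples and batch size 30 (save_freq_in_batches is an illustrative helper, not a Keras function):

```python
import math


def save_freq_in_batches(num_samples: int, batch_size: int, save_period: int) -> int:
    """Number of batches spanning exactly `save_period` epochs."""
    # The last batch of an epoch may be partial, so round the step count up.
    steps_per_epoch = math.ceil(num_samples / batch_size)
    return save_period * steps_per_epoch


# Hypothetical dataset: 1000 samples, batch size 30 (not evenly divisible).
naive = int(10 * (1000 / 30))                 # truncates 333.33... to 333
aligned = save_freq_in_batches(1000, 30, 10)  # 10 epochs * 34 steps = 340
print(naive, aligned)
```

This prints 333 340: only the second value lines up with the end of every 10th epoch.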

The period parameter mentioned in the accepted answer is no longer available.

Using the save_freq parameter is an alternative, but a risky one, as mentioned in the docs: for example, if the dataset size changes, a fixed batch count no longer matches the epoch length and the saves drift away from epoch boundaries. Quoting the docs again: "Note that if the saving isn't aligned to epochs, the monitored metric may potentially be less reliable."
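The drift can be illustrated with a simplified model of batch-based saving: count batches across all epochs and save whenever the running count is a multiple of save_freq. This is an assumption about the mechanism for illustration only (save_points is a hypothetical helper, not a Keras API). With 34 batches per epoch, a save_freq of 333 lands mid-epoch, while 340 lands on the last batch of an epoch:

```python
def save_points(total_epochs: int, steps_per_epoch: int, save_freq: int):
    """Return 1-based (epoch, batch_in_epoch) pairs where a batch-count-based
    save would trigger, assuming the batch counter runs across epochs."""
    points = []
    for step in range(1, total_epochs * steps_per_epoch + 1):
        if step % save_freq == 0:
            epoch = (step - 1) // steps_per_epoch + 1
            batch = step - (epoch - 1) * steps_per_epoch
            points.append((epoch, batch))
    return points


print(save_points(20, 34, 333))  # misaligned: saves land mid-epoch
print(save_points(20, 34, 340))  # aligned: saves land on the last batch
```

The first call reports saves at batch 27 of epoch 10 and batch 20 of epoch 20; the second at batch 34 (the end) of epochs 10 and 20.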

Thus, I use a subclass as a solution:

import tensorflow as tf


class EpochModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
    """ModelCheckpoint that saves only every `frequency` epochs."""

    def __init__(self,
                 filepath,
                 frequency=1,
                 monitor='val_loss',
                 verbose=0,
                 save_best_only=False,
                 save_weights_only=False,
                 mode='auto',
                 options=None,
                 **kwargs):
        # Force epoch-based saving in the parent; the frequency check is in on_epoch_end.
        super().__init__(filepath,
                         monitor=monitor,
                         verbose=verbose,
                         save_best_only=save_best_only,
                         save_weights_only=save_weights_only,
                         mode=mode,
                         save_freq="epoch",
                         options=options)
        self.epochs_since_last_save = 0
        self.frequency = frequency

    def on_epoch_end(self, epoch, logs=None):
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save % self.frequency == 0:
            # _save_model is a private method; its signature varies across TF versions.
            self._save_model(epoch=epoch, batch=None, logs=logs)  # pylint: disable=protected-access

    def on_train_batch_end(self, batch, logs=None):
        # Never save mid-epoch.
        pass

Use it as:

callbacks=[
     EpochModelCheckpoint("/your_save_location/epoch{epoch:02d}", frequency=10),
]
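As a quick sanity check of which files this would produce, here is a pure-Python sketch of the subclass's counter combined with the filepath template. It assumes ModelCheckpoint fills {epoch} with the 1-based epoch number (epoch + 1), which is my understanding of its behaviour; epoch_checkpoint_paths is a hypothetical helper, and the real callback may also fill in metric names from logs, which this sketch ignores.

```python
def epoch_checkpoint_paths(filepath: str, frequency: int, num_epochs: int):
    """List the paths EpochModelCheckpoint would write (simulation only)."""
    epochs_since_last_save = 0
    paths = []
    for epoch in range(num_epochs):  # Keras passes a 0-based epoch index
        epochs_since_last_save += 1
        if epochs_since_last_save % frequency == 0:
            # Assumed: the {epoch} placeholder receives epoch + 1.
            paths.append(filepath.format(epoch=epoch + 1))
    return paths


print(epoch_checkpoint_paths("/your_save_location/epoch{epoch:02d}", 10, 50))
```

For 50 epochs with frequency=10 this lists /your_save_location/epoch10 through /your_save_location/epoch50.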

Note that, depending on your TF version, you may have to change the arguments passed to the superclass __init__ (and possibly the call to the private _save_model method).

I came here looking for this answer too and wanted to point out a couple of changes from the previous answers. I am currently using TF 2.5.0, and period= works, but only if save_freq= is not also passed to the callback.

from tensorflow import keras

my_callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath=path,
        period=N,  # save every N epochs; do not also pass save_freq
    )
]

This works for me without issues, even though period is not documented in the callback documentation.
