
Tensorflow Keras cannot properly resume training at initial epoch from checkpoint file

I am loading a keras model in tensorflow to resume training. I want to continue training from the epoch I stopped at, so that epoch numbers are unique and I can keep track of the total number of epochs. The model is loaded from a checkpoint file created by a callback that saves the model with the highest accuracy. When I resume training in model.fit(), I set initial_epoch to 52 and epochs to 52+5. However, it starts training from epoch 1/57 instead of 53/57 and keeps going up to 57, even though I only want 5 more epochs. Am I loading something wrongly? Training resumes as normal and accuracy is where I left off, but the epoch numbers do not continue from where I want and keep restarting from 1.

I have tried removing the checkpoint callback initialisation when loading from the checkpoint file, but that generates a name error because callbacks_list is not defined.

model = load_model('my_model.hdf5')
checkpoint = ModelCheckpoint(cp_filepath, monitor='acc', 
verbose=1, save_best_only=True, mode='max')
callbacks_list = [checkpoint]

bs=32 #batch size
epoch count=52
cur_epochs=5
model.fit(
    training_set,
    steps_per_epoch=len(training_set)//bs,
    inital_epoch=epoch_count,
    epochs=cur_epochs+epoch_count,
    validation_data=test_set,
    validation_steps=len(test_set)//bs,
    callbacks=callbacks_list, 
    shuffle=True,
    verbose=1
    )

I expect to see epoch 53/57 and 5 epochs of training when resuming from a saved file. Instead I get epoch 1/57 and 57 epochs of training.

I noticed that you forgot to put an underscore in epoch_count. That might be what is causing it.
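A minimal corrected sketch of the question's fit() call, assuming the rest of the setup is unchanged (note that the initial_epoch keyword is also misspelled as inital_epoch in the question):

bs = 32           # batch size
epoch_count = 52  # epochs already completed
cur_epochs = 5    # additional epochs to train

model.fit(
    training_set,
    steps_per_epoch=len(training_set) // bs,
    initial_epoch=epoch_count,         # resume numbering at epoch 53/57
    epochs=cur_epochs + epoch_count,   # stop after 5 more epochs
    validation_data=test_set,
    validation_steps=len(test_set) // bs,
    callbacks=callbacks_list,
    shuffle=True,
    verbose=1,
)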

Had the same issue. To solve it I modified the ModelCheckpoint (Callback) class.

I added and saved a dedicated tensorflow checkpoint for the epoch in the on_epoch_begin callback function.

The network doesn't store its training progress with respect to the training data; this is not part of its state, because at any point you could decide to change which data set to feed it.
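In other words, the number of completed epochs has to be persisted and passed back into fit() yourself. A bare-bones sketch of that idea follows (the epoch.json path and helper names are made up for illustration; the EpochModelCheckpoint class below does the same thing more robustly with tf.train.Checkpoint):

import json
import os

EPOCH_FILE = "epoch.json"  # hypothetical location for the epoch counter

def load_completed_epochs():
    # Returns 0 on the first run, otherwise the last persisted epoch count.
    if os.path.exists(EPOCH_FILE):
        with open(EPOCH_FILE) as f:
            return json.load(f)["completed_epochs"]
    return 0

def save_completed_epochs(epoch):
    # Call this from a callback (e.g. on_epoch_end) after each epoch finishes,
    # then pass the loaded value to model.fit(initial_epoch=...) when resuming.
    with open(EPOCH_FILE, "w") as f:
        json.dump({"completed_epochs": epoch}, f)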

import os

import tensorflow as tf


class EpochModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):

    def __init__(self, filepath, monitor='val_loss', verbose=1,
                 save_best_only=True, save_weights_only=True,
                 mode='auto'):

        super(EpochModelCheckpoint, self).__init__(filepath=filepath, monitor=monitor,
             verbose=verbose, save_best_only=save_best_only,
             save_weights_only=save_weights_only, mode=mode)

        # Dedicated tf.train.Checkpoint that only tracks the number of completed epochs.
        self.ckpt = tf.train.Checkpoint(completed_epochs=tf.Variable(0, trainable=False, dtype='int32'))
        ckpt_dir = f'{os.path.dirname(filepath)}/tf_ckpts'
        self.manager = tf.train.CheckpointManager(self.ckpt, ckpt_dir, max_to_keep=3)

    def on_epoch_begin(self, epoch, logs=None):
        # Record the index of the epoch that is starting, so a restarted run
        # can pass it back to model.fit() as initial_epoch.
        self.ckpt.completed_epochs.assign(epoch)
        self.manager.save()
        print(f"Epoch checkpoint {self.ckpt.completed_epochs.numpy()} saved to: {self.manager.latest_checkpoint}")
        print(logs)

def callbacks(checkpoint_dir, model_name):

    best_model = os.path.join(checkpoint_dir, '{}_best.hdf5'.format(model_name))
    save_best = EpochModelCheckpoint(best_model)
    return [ save_best ]

def train():

    ...

    model = get_compiled_model()
    checkpoint_dir = "./checkpoint_dir"
    model_name = "my_model"
    # Init checkpoint value
    ckpt = tf.train.Checkpoint(completed_epochs=tf.Variable(0,trainable=False,dtype='int32'))
    manager = tf.train.CheckpointManager(ckpt, f'{checkpoint_dir}/tf_ckpts', max_to_keep=3)    

    best_weights = os.path.join(checkpoint_dir, f'{model_name}_best.hdf5') 
    if os.path.exists(best_weights):
        print(f'Loading model {best_weights}')
        model.load_weights(best_weights)

        # Restore last Epoch
        ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            print(f"Restored epoch ckpt from {manager.latest_checkpoint}, value is ",ckpt.completed_epochs.numpy())
        else:
            print("Initializing from scratch.")

    ...
    # Set initial_epoch in model.fit() to the last seen epoch
    completed_epochs = ckpt.completed_epochs.numpy()
    history = model.fit(
        x=train_ds,
        epochs=cfg.epochs,
        steps_per_epoch=cfg.steps,
        callbacks=callbacks( checkpoint_dir, model_name ),        
        validation_data=val_ds,
        validation_steps=cfg.val_steps,
        initial_epoch=completed_epochs )
