
tf.train.Checkpoint and loading weights

I'm training a seq2seq model using TensorFlow. Correct me if I'm wrong: I understand that tf.train.Checkpoint saves only checkpoint files, which are useful only when the source code that will use the saved parameter values is available. I would like to know how I could instantiate my model later on and load the trained weights from a checkpoint in order to test it.

checkpoint_dir = 'training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)

Here is the code for training:

EPOCHS = 20

for epoch in range(EPOCHS):
  start = time.time()

  enc_hidden = encoder.initialize_hidden_state()
  total_loss = 0

  for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):
    batch_loss = train_step(inp, targ, enc_hidden)
    total_loss += batch_loss
    if batch % 100 == 0:
      print('Epoch {} Batch {} loss {}'.format(epoch + 1,batch, batch_loss.numpy()))
   
      
  # saving (checkpoint) the model every 2 epochs
  if (epoch + 1) % 2 == 0:
    checkpoint.save(file_prefix = checkpoint_prefix)
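
In miniature, the save/restore round trip I'm asking about would look something like this (a plain `tf.Variable` stands in for the encoder/decoder here, and the variable names are made up for illustration):

```python
import os
import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()

# Training time: build the objects and save their values.
v = tf.Variable(3.0)  # stand-in for the encoder/decoder weights
ckpt = tf.train.Checkpoint(v=v)
ckpt.save(os.path.join(ckpt_dir, "ckpt"))

# Test time: re-instantiate objects with the SAME structure, then restore.
v_restored = tf.Variable(0.0)
ckpt_restored = tf.train.Checkpoint(v=v_restored)
ckpt_restored.restore(tf.train.latest_checkpoint(ckpt_dir))
print(v_restored.numpy())  # 3.0
```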

Regards

Here is a proposed answer which suggests using a checkpoint manager:

    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     encoder=encoder,
                                     decoder=decoder)
    manager = tf.train.CheckpointManager(checkpoint, './tf_ckpts', max_to_keep=3)

    def train_and_checkpoint(net, manager):
        # net is your custom model; manager handles the checkpoint files.
        checkpoint.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            print("Restored from {}".format(manager.latest_checkpoint))
        else:
            print("Initializing from scratch.")
        EPOCHS = 20

        for epoch in range(EPOCHS):
            start = time.time()

            enc_hidden = encoder.initialize_hidden_state()
            total_loss = 0

            for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):
                batch_loss = train_step(inp, targ, enc_hidden)
                total_loss += batch_loss
                if batch % 100 == 0:
                    print('Epoch {} Batch {} Loss {}'.format(epoch + 1, batch, batch_loss.numpy()))

            # Save (checkpoint) the model every 2 epochs.
            if (epoch + 1) % 2 == 0:
                save_path = manager.save()
                print("Saved checkpoint for epoch {}: {}".format(epoch + 1, save_path))

    # Run the function once to train and save checkpoints.
    train_and_checkpoint(net, manager)

    # Instantiate a new model, restore the weights, and resume training
    # from the last checkpoint.
    opt = optimizer  # the optimizer passed earlier
    net = Net()      # your custom model

    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     encoder=encoder,
                                     decoder=decoder)
    manager = tf.train.CheckpointManager(checkpoint, './tf_ckpts', max_to_keep=3)

    # This restores the weights from the last checkpoint and starts training again.
    train_and_checkpoint(net, manager)
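
If the goal is only to test the model rather than resume training, a new `tf.train.Checkpoint` can be built without the optimizer and restored with `expect_partial()`, which suppresses warnings about the optimizer slots left unrestored. A self-contained sketch, with a small `Dense` layer standing in for the encoder/decoder:

```python
import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()

# Training time: model and optimizer tracked together in one checkpoint.
model = tf.keras.layers.Dense(2)
model(tf.zeros([1, 4]))  # call once so the layer builds its weights
optimizer = tf.keras.optimizers.Adam()
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
manager = tf.train.CheckpointManager(checkpoint, ckpt_dir, max_to_keep=3)
manager.save()

# Test time: rebuild only the model, restore, and ignore optimizer state.
test_model = tf.keras.layers.Dense(2)
test_model(tf.zeros([1, 4]))
tf.train.Checkpoint(model=test_model).restore(
    manager.latest_checkpoint).expect_partial()
```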

Ref - https://www.tensorflow.org/guide/checkpoint#train_and_checkpoint_the_model
