简体   繁体   English

tf.keras.callbacks.ModelCheckpoint vs tf.train.Checkpoint

[英]tf.keras.callbacks.ModelCheckpoint vs tf.train.Checkpoint

I am kinda new to TensorFlow world but have written some programs in Keras.我对 TensorFlow 世界有点陌生,但在 Keras 中编写了一些程序。 Since TensorFlow 2 is officially similar to Keras, I am quite confused about what is the difference between tf.keras.callbacks.ModelCheckpoint and tf.train.Checkpoint.由于 TensorFlow 2 官方与 Keras 类似,我很困惑 tf.keras.callbacks.ModelCheckpoint 和 tf.train.Checkpoint 有什么区别If anybody can shed light on this, I would appreciate it.如果有人能阐明这一点,我将不胜感激。

It depends on whether a custom training loop is required.这取决于是否需要自定义训练循环。 In most cases, it's not and you can just call model.fit() and pass tf.keras.callbacks.ModelCheckpoint .在大多数情况下,它不是,您只需调用model.fit()并传递tf.keras.callbacks.ModelCheckpoint If you do need to write your custom training loop, then you have to use tf.train.Checkpoint (and tf.train.CheckpointManager ) since there's no callback mechanism.如果您确实需要编写自定义训练循环,那么您必须使用tf.train.Checkpoint (和tf.train.CheckpointManager ),因为没有回调机制。

I also had a hard time differentiating between the checkpoint objects used when I looked at other people's code, so I wrote down some notes about when to use which one and how to use them in general.当我查看其他人的代码时,我也很难区分使用的检查点对象,所以我写了一些关于何时使用哪个以及一般如何使用它们的注释。 Either-way, I think it might be useful for other people having the same issue:无论哪种方式,我认为它可能对其他有同样问题的人有用:

Saving model Checkpoints保存 model 检查点

These are 2 ways of saving your model's checkpoints, each is for a different use case:这是保存模型检查点的两种方法,每种方法都针对不同的用例:

1) Checkpoint & CheckpointManager 1) 检查点和检查点管理器

This is use-full when you are managing the training loops yourself.当您自己管理训练循环时,这很有用。

You use them like this:你像这样使用它们:

1.1) Checkpoint 1.1) 检查点

Definition from the docs : " A Checkpoint object can be constructed to save either a single or group of trackable objects to a checkpoint file ". 文档中的定义:“可以构造检查点 object 以将单个或一组可跟踪对象保存到检查点文件”。

How to initialise it:如何初始化它:

  • You can pass it key value pairs for:您可以将键值对传递给:
    • All the custom function calls or objects that make up your model and you want to keep track of:所有自定义 function 调用或构成您的 model 的对象,并且您想要跟踪:
    • Like a generator, discriminiator, loss function, optimizer etc像生成器、判别器、损失 function、优化器等
ckpt = Checkpoint(discr_opt=discr_opt, genrt_opt=genrt_opt, wgan = wgan, d_model = d_model, g_model = g_model)
1.2) CheckpointManager 1.2) 检查点管理器

This literally manages the checkpoints you have defined to be stored at a location and things like how many to to keep.这实际上管理了您定义的要存储在某个位置的检查点以及要保留的数量。 Definition from the docs : " Manages multiple checkpoints by keeping some and deleting unneeded ones " 文档中的定义:“通过保留一些检查点并删除不需要的检查点来管理多个检查点

How to initialise it:如何初始化它:

  • Initialise it with the CheckPoint object you create as first argument.使用您作为第一个参数创建的 CheckPoint object 对其进行初始化。
  • The directory where to save the checkpoint files.保存检查点文件的目录。
  • And you probably want to define how much you want to keep, since this can be a lot of complex models你可能想定义你想保留多少,因为这可能是很多复杂的模型
manager = CheckpointManager(ckpt, "training_checkpoints_wgan", max_to_keep=3)

How to use it:如何使用它:

  • We have setup the manager object with our specified checkpoints, so it's ready to use.我们已经使用我们指定的检查点设置了管理器 object,因此它可以使用了。
  • Call this at the end of each training epoch在每个训练时期结束时调用它
manager.save()

2) ModelCheckpoint (callback) 2)模型检查点(回调)

You want to use this callback when you are not managing epoch iterations yourself.当您不自己管理 epoch 迭代时,您希望使用此回调 For example when you have setup a relatively simple Sequential model and you call model.fit(), which manages the training process for you.例如,当您设置了一个相对简单的 Sequential model 并调用 model.fit() 时,它会为您管理训练过程。

Definition from the docs : " Callback to save the Keras model or model weights at some frequency. " 文档中的定义:“回调以在某个频率保存 Keras model 或 model 权重。

How to initialise it:如何初始化它:

  • Pass in the path where to save the model传入保存model的路径

  • The option save_weights_only is set to False by default:默认情况下,选项save_weights_only设置为 False:

    • If you want to only save the weights make sure to update this如果您只想保存权重,请确保更新此
  • The option save_best_only is set to False by default:默认情况下,选项save_best_only设置为 False:

    • If you want to only save the best model instead of all of them, you can set this to True.如果您只想保存最好的 model 而不是全部,您可以将其设置为 True。
  • verbose is set to 0 (False), so you can update this to 1 to validate it详细设置为 0(假),因此您可以将其更新为 1 以验证它

mc = ModelCheckpoint("training_checkpoints/cp.ckpt", save_best_only=True, save_weights_only=False)

How to use it:如何使用它:

  • The model checkpoint callback is now ready to for training. model 检查点回调现已准备好进行训练。
  • You pass in the object in you into your callbacks list when you fit the model:当您适合 model 时,您将 object 传递到您的回调列表中:
 model.fit(X, y, epochs=100, callbacks=[mc])

TensorFlow is a 'computation' library and Keras is a Deep Learning library which can work with TF or PyTorch, etc. So what TF provides is a more generic not-so-customized-for-deep-learning version. TensorFlow 是一个“计算”库,Keras 是一个深度学习库,可以与 TF 或 PyTorch 等一起使用。所以 TF 提供的是一个更通用的非深度学习版本。 If you just compare the docs you can see how more comprehensive and customized ModelCheckpoint is.如果您只是比较文档,您会发现ModelCheckpoint更加全面和定制。 Checkpoint just reads and writes stuff from/to disk.检查点只是从/向磁盘读取和写入内容。 ModelCheckpoint is much smarter! ModelCheckpoint更智能!

Also, ModelCheckpoint is a callback.此外, ModelCheckpoint是一个回调。 It means you can just make an instance of it and pass it to the fit function:这意味着您可以只创建一个实例并将其传递给fit function:

model_checkpoint = ModelCheckpoint(...)
model.fit(..., callbacks=[..., model_checkpoint, ...], ...)

I took a quick look at Keras's implementation of ModelCheckpoint , it calls either save or save_weights method on Model which in some cases uses TensorFlow's CheckPoint itself.我快速浏览了 Keras 的ModelCheckpoint实现,它在Model上调用savesave_weights方法,在某些情况下使用 TensorFlow 的CheckPoint本身。 So it is not a wrapper per se but certainly is on a lower level of abstraction -- more specialized for saving Keras models.所以它本身不是一个包装器,但肯定是在较低的抽象层次上——更专门用于保存 Keras 模型。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM