tf.keras.callbacks.ModelCheckpoint vs tf.train.Checkpoint
I am kind of new to the TensorFlow world but have written some programs in Keras. Since TensorFlow 2 officially integrates Keras, I am quite confused about the difference between tf.keras.callbacks.ModelCheckpoint and tf.train.Checkpoint. If anybody can shed light on this, I would appreciate it.
It depends on whether a custom training loop is required. In most cases it is not, and you can just call model.fit() and pass in a tf.keras.callbacks.ModelCheckpoint. If you do need to write your own training loop, then you have to use tf.train.Checkpoint (and tf.train.CheckpointManager), since there is no callback mechanism.
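To make the custom-loop case concrete, here is a minimal sketch assuming a toy model and random data (all names, shapes, and the checkpoint directory are illustrative, not from the question):

```python
import tensorflow as tf

# Illustrative model, optimizer, and data for a hand-written training loop.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.MeanSquaredError()

# Track whatever state you want restored; here just the model.
ckpt = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(ckpt, "./ckpts", max_to_keep=3)

# Resume if a checkpoint already exists (restore(None) is a no-op on the first run).
ckpt.restore(manager.latest_checkpoint)

x = tf.random.normal((16, 4))
y = tf.random.normal((16, 1))

for epoch in range(2):  # the loop you manage yourself
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    manager.save()  # no callback hook exists here, so you call save() explicitly
```

The key difference from the model.fit() path is that last line: saving is an explicit step inside your loop rather than something a callback triggers for you.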
I also had a hard time differentiating between the checkpoint objects when I looked at other people's code, so I wrote down some notes about when to use which one and how to use them in general. Either way, I think they might be useful for other people having the same issue:
These are two ways of saving your model's checkpoints; each is for a different use case:
1) tf.train.Checkpoint and tf.train.CheckpointManager

These are useful when you are managing the training loop yourself. You use them like this:

1.1) Checkpoint

Definition from the docs: "A Checkpoint object can be constructed to save either a single or group of trackable objects to a checkpoint file."

How to initialise it:

ckpt = tf.train.Checkpoint(discr_opt=discr_opt, genrt_opt=genrt_opt, wgan=wgan, d_model=d_model, g_model=g_model)
1.2) CheckpointManager
This literally manages the checkpoints you have defined to be stored at a location, and things like how many to keep. Definition from the docs: "Manages multiple checkpoints by keeping some and deleting unneeded ones."

How to initialise it:
manager = tf.train.CheckpointManager(ckpt, "training_checkpoints_wgan", max_to_keep=3)
How to use it:
manager.save()
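Putting the two together, here is a small save-and-restore sketch. It stands in for the GAN objects above with a single toy model plus a step counter (purely illustrative):

```python
import tensorflow as tf

# Stand-ins for the objects in the snippets above: one toy model and a step counter.
net = tf.keras.Sequential([tf.keras.Input(shape=(2,)), tf.keras.layers.Dense(1)])
ckpt = tf.train.Checkpoint(step=tf.Variable(0), model=net)
manager = tf.train.CheckpointManager(ckpt, "training_checkpoints_wgan", max_to_keep=3)

ckpt.step.assign_add(1)
save_path = manager.save()  # writes e.g. training_checkpoints_wgan/ckpt-1

# Later (or in a new process): resume from the most recent checkpoint.
ckpt.restore(manager.latest_checkpoint)
```

Because max_to_keep=3, older checkpoint files are deleted automatically as new ones are saved; manager.latest_checkpoint always points at the newest one.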
2) tf.keras.callbacks.ModelCheckpoint

You want to use this callback when you are not managing the epoch iterations yourself, for example when you have set up a relatively simple Sequential model and you call model.fit(), which manages the training process for you.

Definition from the docs: "Callback to save the Keras model or model weights at some frequency."
How to initialise it:
Pass in the path where to save the model.
The option save_weights_only is set to False by default.
The option save_best_only is set to False by default.
verbose is set to 0 (False), so you can set it to 1 to see what is being saved.
mc = ModelCheckpoint("training_checkpoints/cp.ckpt", save_best_only=True, save_weights_only=False)
How to use it:
model.fit(X, y, epochs=100, callbacks=[mc])
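A minimal end-to-end sketch of that flow, assuming toy data and an illustrative file name (the legacy HDF5 format is used here for portability; newer Keras versions also accept the .keras format):

```python
import tensorflow as tf

# Toy data and model; names and sizes are arbitrary.
x = tf.random.normal((64, 4))
y = tf.random.normal((64, 1))

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# save_best_only monitors val_loss by default, so validation data is needed;
# without it the callback warns and falls back to saving every epoch.
mc = tf.keras.callbacks.ModelCheckpoint(
    "best_model.h5",
    monitor="val_loss",
    save_best_only=True,
)
model.fit(x, y, epochs=3, validation_split=0.25, callbacks=[mc], verbose=0)

# The best model can be loaded back later:
restored = tf.keras.models.load_model("best_model.h5")
```

Note that with save_best_only=True the file on disk is only overwritten when the monitored metric improves, so after training it holds the best epoch, not necessarily the last one.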
TensorFlow is a 'computation' library and Keras is a deep learning library which can work with TF or PyTorch, etc. So what TF provides is a more generic, not-so-customized-for-deep-learning version. If you just compare the docs you can see how much more comprehensive and customized ModelCheckpoint is. Checkpoint just reads and writes stuff from/to disk; ModelCheckpoint is much smarter!
Also, ModelCheckpoint is a callback, meaning you can just make an instance of it and pass it to the fit function:

model_checkpoint = ModelCheckpoint(...)
model.fit(..., callbacks=[..., model_checkpoint, ...], ...)
I took a quick look at Keras's implementation of ModelCheckpoint: it calls either the save or save_weights method on Model, which in some cases uses TensorFlow's Checkpoint itself. So it is not a wrapper per se, but it certainly sits at a higher level of abstraction, more specialized for saving Keras models.
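A rough sketch of what the callback delegates to, assuming a toy model (file names are illustrative; the real callback adds metric monitoring, epoch-based file naming, and so on):

```python
import tensorflow as tf

# Toy model standing in for whatever model.fit() is training.
model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])

model.save("whole_model.h5")                    # roughly what save_weights_only=False does
model.save_weights("weights_only.weights.h5")   # roughly what save_weights_only=True does
```

So the callback is a scheduling-and-selection layer on top of Model.save / Model.save_weights, which is exactly why it feels so much more tailored to Keras workflows than the raw Checkpoint API.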