简体   繁体   English

在 Tensorflow 2 中为 Model(object) 类创建检查点

[英]Create a checkpoint for class Model(object) in Tensorflow 2

I was just trying to use low level API of tensorflow2.我只是想使用 tensorflow2 的低级 API。 I created my model based on this tutorial : https://www.tensorflow.org/tutorials/customization/custom_training#define_the_model我根据本教程创建了我的模型: https : //www.tensorflow.org/tutorials/customization/custom_training#define_the_model

Then I want to create a checkpoint for my training process, and I follow this tutorial : https://www.tensorflow.org/guide/checkpoint然后我想为我的训练过程创建一个检查点,我按照本教程进行操作: https : //www.tensorflow.org/guide/checkpoint

The problem is the checkpoint's tutorial use a class with tf.keras.Model as parameter, while I use object as my parameter.问题是检查点的教程使用一个以 tf.keras.Model 作为参数的类,而我使用 object 作为我的参数。 It gave me error, said that it was expecting a trackable object.它给了我错误,说它期待一个可追踪的对象。

Here is the snippet of my code:这是我的代码片段:

class SimpleANN(object):
    def __init__(self):
        initializer = tf.initializers.glorot_uniform()
        self.w1 = tf.Variable(initializer([784, 360]), name = 'weight1', trainable = True, dtype = tf.float32)
        self.w2 = tf.Variable(initializer([360, 64]), name = 'weight2', trainable = True, dtype = tf.float32)
        self.w3 = tf.Variable(initializer([64, 10]), name = 'weight3', trainable = True, dtype = tf.float32)

    def __call__(self, x, leaky_relu_alpha = 0.2):
        fc1 = tf.nn.leaky_relu(tf.matmul(x, self.w1), alpha = leaky_relu_alpha)
        fc2 = tf.nn.leaky_relu(tf.matmul(fc1, self.w2), alpha = leaky_relu_alpha)
        logits = tf.matmul(fc2, self.w3)

        return logits

model = SimpleANN() 
optimizer = tf.keras.optimizers.Adam(learning_rate)

ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, model = model)

then I got this error :然后我收到了这个错误:

ValueError: `Checkpoint` was expecting a trackable object (an object
derived from `TrackableBase`), got <__main__.SimpleANN object at
0x000001D859792748>. If you believe this object should be trackable
(i.e. it is part of the TensorFlow Python API and manages state),
please open an issue.

I would like to know If it is able to implement tf.train.Checkpoint for the low level API, as what I was doing.我想知道它是否能够像我所做的那样为低级 API 实现 tf.train.Checkpoint。

What you try is far less, then a tensorflow keras Model does.您尝试的要少得多,然后 tensorflow keras模型就可以了。 If you really want to know, what does "low level" model creation mean, take a visit to github: tensorflow/tensorflow/python/keras/engine/training.py , where you will find a lot reference to of what you need to do to reach success with your approach eg @trackable.no_automatic_dependency_tracking decorator.如果你真的想知道,“低级”模型创建是什么意思,请访问 github: tensorflow/tensorflow/python/keras/engine/training.py ,在那里你会找到很多你需要的参考使用您的方法取得成功,例如@trackable.no_automatic_dependency_tracking装饰器。 As you will see there, even Model() class has parameters, you have to dive into.正如您将在那里看到的,即使是Model()类也有参数,您必须深入研究。 Obviously it is not impossible, but you have to dive deeper.显然这不是不可能的,但你必须深入研究。

The TensorFlow SavedModel API can only be used to save trackable objects (which Keras models are by default). TensorFlow SavedModel API 只能用于保存可跟踪对象(Keras 模型默认是这些对象)。 One way to create a trackable object using the low-level API is to inherit from tf.Module .使用低级 API 创建可跟踪对象的一种方法是从tf.Module继承。 In my environment (Python version 3.7.6, TensorFlow version 2.1.0), I can make the errors go away by replacing the line class SimpleANN(object): with class SimpleANN(tf.Module): .在我的环境(Python 版本 3.7.6,TensorFlow 版本 2.1.0)中,我可以通过将行class SimpleANN(object):替换为class SimpleANN(tf.Module):

You may or may not also need to decorate your methods with @tf.function , and call them once to trace-compile a graph before saving a checkpoint.您可能也可能不需要用@tf.function装饰您的方法,并在保存检查点之前调用它们一次以跟踪编译图形。 For more information, see my answer to RobR's question here .有关更多信息,请参阅我对 RobR 问题的回答

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 张量流对象检测从现有检查点微调模型 - tensorflow object detection Fine-tuning a model from an existing checkpoint 加载张量流检查点作为keras模型 - Load tensorflow checkpoint as keras model 如何在Tensorflow中为当前模型恢复预训练的检查点? - How to restore pretrained checkpoint for current model in Tensorflow? 使用Tensorflow检查点在C ++中恢复模型 - Using Tensorflow checkpoint to restore model in C++ TensorFlow:有没有办法将冻结图转换为检查点模型? - TensorFlow: Is there a way to convert a frozen graph into a checkpoint model? 从张量流检查点加载特定模型时出错 - Error in loading particular model from tensorflow checkpoint 从先前的检查点还原Tensorflow模型 - Restoring Tensorflow model from a previous checkpoint 是否可以从 Tensorflow 中的检查点 model 恢复训练? - Is it possible to resume training from a checkpoint model in Tensorflow? 无法加载预训练的 model 检查点与 TensorFlow Object 检测 ZDB974238714CA8DE634A7ACE1 - Unable to load pre-trained model checkpoint with TensorFlow Object Detection API 了解 Tensorflow 对象检测 API,检查点的 kwargs class,什么是`_base_tower_layers_for_heads`? - Understanding Tensorflow Object-Detection API, kwargs for Checkpoint class, what is `_base_tower_layers_for_heads`?
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM