
How to save and load selected and all variables in TensorFlow 2.0 using tf.train.Checkpoint?

How do I save selected variables in TensorFlow 2.0, shown below, to a file, and load them into some defined variables in another piece of code using tf.train.Checkpoint?

import tensorflow as tf

class manyVariables:
    def __init__(self):
        self.initList = [None]*100
        for i in range(100):
            self.initList[i] = tf.Variable(tf.random.normal([5,5]))
        self.makeSomeMoreVariables()

    def makeSomeMoreVariables(self):
        self.moreList = [None]*10
        for i in range(10):
            self.moreList[i] = tf.Variable(tf.random.normal([3,3]))

    def saveVariables(self):
        # How to save self.initList's 3rd, 55th and 60th elements
        # and self.moreList's 4th element?
        pass

Also, please show how to save all the variables and reload them using tf.train.Checkpoint. Thanks in advance.

In the following code I save a list called variables to a .txt file with a name of your choosing. This file will be in the same folder as your Python file. The 'wb' in the open function means write with truncation (removing everything that was previously in the file) in byte format. I use pickle to handle saving and parsing the list.

import pickle  # at the top of your file

    def saveVariables(self, variables):  # 'variables' is a list of variables to save
        with open("nameOfYourFile.txt", 'wb') as file:
            pickle.dump(variables, file)

    def retrieveVariables(self, filename):
        with open(str(filename), 'rb') as file:
            variables = pickle.load(file)
        return variables

To save specific things to your file, just pass them as the variables argument to saveVariables, like so:

myVariables = [initList[2], initList[54], initList[59], moreList[3]]
saveVariables(myVariables)

To retrieve the variables from a file with a certain name:

myVariables = retrieveVariables("theNameOfYourFile.txt")
thirdEl = myVariables[0]
fiftyFifthEl = myVariables[1]
sixtiethEl = myVariables[2]
fourthEl = myVariables[3]

You could add these functions anywhere in the class.

To be able to access initList/moreList in your example, however, you should either return them from their functions (as I do with the variables list) or make them global.
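One caveat with the pickle approach above: tf.Variable objects themselves may not pickle cleanly in TensorFlow 2, so a safer variant is to pickle the underlying NumPy values and assign them back into existing variables on load. A minimal sketch (the function names here are illustrative, not from the original answer):

import pickle

import tensorflow as tf

def saveVariableValues(variables, filename):
    # Pickle the raw NumPy values rather than the tf.Variable objects.
    values = [v.numpy() for v in variables]
    with open(filename, 'wb') as file:
        pickle.dump(values, file)

def restoreVariableValues(variables, filename):
    # Load the NumPy arrays and assign them back into existing variables,
    # which must already have matching shapes.
    with open(filename, 'rb') as file:
        values = pickle.load(file)
    for var, value in zip(variables, values):
        var.assign(value)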

I'm not sure if this is what you mean, but you can create a tf.train.Checkpoint object specifically for the variables that you want to save and restore. See the following example:

import tensorflow as tf

class manyVariables:
    def __init__(self):
        self.initList = [None]*100
        for i in range(100):
            self.initList[i] = tf.Variable(tf.random.normal([5,5]))
        self.makeSomeMoreVariables()
        self.ckpt = self.makeCheckpoint()

    def makeSomeMoreVariables(self):
        self.moreList = [None]*10
        for i in range(10):
            self.moreList[i] = tf.Variable(tf.random.normal([3,3]))

    def makeCheckpoint(self):
        return tf.train.Checkpoint(
            init3=self.initList[3], init55=self.initList[55],
            init60=self.initList[60], more4=self.moreList[4])

    def saveVariables(self):
        self.ckpt.save('./ckpt')

    def restoreVariables(self):
        status = self.ckpt.restore(tf.train.latest_checkpoint('.'))
        status.assert_consumed()  # Optional check

# Create variables
v1 = manyVariables()
# Assign fixed values
for i, v in enumerate(v1.initList):
    v.assign(i * tf.ones_like(v))
for i, v in enumerate(v1.moreList):
    v.assign(100 + i * tf.ones_like(v))
# Save them
v1.saveVariables()

# Create new variables
v2 = manyVariables()
# Check initial values
print(v2.initList[2].numpy())
# [[-1.9110833   0.05956204 -1.1753829  -0.3572553  -0.95049495]
#  [ 0.31409055  1.1262076   0.47890127 -0.1699607   0.4409122 ]
#  [-0.75385517 -0.13847834  0.97012395  0.42515194 -1.4371008 ]
#  [ 0.44205236  0.86158335  0.6919655  -2.5156968   0.16496429]
#  [-1.241602   -0.15177743  0.5603795  -0.3560254  -0.18536267]]
print(v2.initList[3].numpy())
# [[-3.3441594  -0.18425298 -0.4898144  -1.2330629   0.08798431]
#  [ 1.5002227   0.99475247  0.7817361   0.3849587  -0.59548247]
#  [-0.57121766 -1.277224    0.6957546  -0.67618763  0.0510064 ]
#  [ 0.85491985  0.13310803 -0.93152267  0.10205163  0.57520276]
#  [-1.0606447  -0.16966362 -1.0448577   0.56799036 -0.90726566]]

# Restore them
v2.restoreVariables()
# Check values after restoring
print(v2.initList[2].numpy())
# [[-1.9110833   0.05956204 -1.1753829  -0.3572553  -0.95049495]
#  [ 0.31409055  1.1262076   0.47890127 -0.1699607   0.4409122 ]
#  [-0.75385517 -0.13847834  0.97012395  0.42515194 -1.4371008 ]
#  [ 0.44205236  0.86158335  0.6919655  -2.5156968   0.16496429]
#  [-1.241602   -0.15177743  0.5603795  -0.3560254  -0.18536267]]
print(v2.initList[3].numpy())
# [[3. 3. 3. 3. 3.]
#  [3. 3. 3. 3. 3.]
#  [3. 3. 3. 3. 3.]
#  [3. 3. 3. 3. 3.]
#  [3. 3. 3. 3. 3.]]

If you want to save all the variables in the lists, you could replace makeCheckpoint with something like this:

def makeCheckpoint(self):
    return tf.train.Checkpoint(
        **{f'init{i}': v for i, v in enumerate(self.initList)},
        **{f'more{i}': v for i, v in enumerate(self.moreList)})
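Because a checkpoint matches variables by the names of its keyword arguments, you can also restore just a subset of a full checkpoint saved this way, by building a second checkpoint object with only the keys you care about. A minimal sketch (assuming the full checkpoint above has already been saved; restoreSubset is an illustrative name):

def restoreSubset(self):
    # Restore only 'init3' and 'more4' from the full checkpoint;
    # expect_partial() silences warnings about the keys left unrestored.
    partial = tf.train.Checkpoint(init3=self.initList[3], more4=self.moreList[4])
    partial.restore(tf.train.latest_checkpoint('.')).expect_partial()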

Note that you can have "nested" checkpoints, so, more generally, you could have a function that makes a checkpoint for a list of variables, for example like this:

def listCheckpoint(varList):
    # Use 'item{}'.format(i) if using Python <3.6
    return tf.train.Checkpoint(**{f'item{i}': v for i, v in enumerate(varList)})

Then you could just have this:

def makeCheckpoint(self):
    return tf.train.Checkpoint(init=listCheckpoint(self.initList),
                               more=listCheckpoint(self.moreList))
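As a side note, if you save repeatedly, tf.train.CheckpointManager can handle numbering and pruning of old checkpoint files for you. A brief sketch (the directory name and max_to_keep value are arbitrary choices):

v = manyVariables()
manager = tf.train.CheckpointManager(v.ckpt, directory='./ckpts', max_to_keep=3)
save_path = manager.save()                  # e.g. './ckpts/ckpt-1'
v.ckpt.restore(manager.latest_checkpoint)   # restore the most recent save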
